diff --git a/CMakeLists.txt b/CMakeLists.txt index 74433f333..39d1e4307 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,8 +101,6 @@ if(GTSAM_BUILD_PYTHON OR GTSAM_INSTALL_MATLAB_TOOLBOX) # Copy matlab.h to the correct folder. configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h ${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY) - # Add the include directories so that matlab.h can be found - include_directories("${PROJECT_BINARY_DIR}" "${GTSAM_EIGEN_INCLUDE_FOR_BUILD}") add_subdirectory(wrap) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake") diff --git a/cmake/Config.cmake.in b/cmake/Config.cmake.in index 89627a172..cc2a7df8f 100644 --- a/cmake/Config.cmake.in +++ b/cmake/Config.cmake.in @@ -21,6 +21,10 @@ else() find_dependency(Boost @BOOST_FIND_MINIMUM_VERSION@ COMPONENTS @BOOST_FIND_MINIMUM_COMPONENTS@) endif() +if(@GTSAM_USE_SYSTEM_EIGEN@) +find_dependency(Eigen3 REQUIRED) +endif() + # Load exports include(${OUR_CMAKE_DIR}/@PACKAGE_NAME@-exports.cmake) diff --git a/cmake/FindEigen3.cmake b/cmake/FindEigen3.cmake deleted file mode 100644 index 9c546a05d..000000000 --- a/cmake/FindEigen3.cmake +++ /dev/null @@ -1,81 +0,0 @@ -# - Try to find Eigen3 lib -# -# This module supports requiring a minimum version, e.g. you can do -# find_package(Eigen3 3.1.2) -# to require version 3.1.2 or newer of Eigen3. -# -# Once done this will define -# -# EIGEN3_FOUND - system has eigen lib with correct version -# EIGEN3_INCLUDE_DIR - the eigen include directory -# EIGEN3_VERSION - eigen version - -# Copyright (c) 2006, 2007 Montel Laurent, -# Copyright (c) 2008, 2009 Gael Guennebaud, -# Copyright (c) 2009 Benoit Jacob -# Redistribution and use is allowed according to the terms of the 2-clause BSD license. - -if(NOT Eigen3_FIND_VERSION) - if(NOT Eigen3_FIND_VERSION_MAJOR) - set(Eigen3_FIND_VERSION_MAJOR 2) - endif(NOT Eigen3_FIND_VERSION_MAJOR) - if(NOT Eigen3_FIND_VERSION_MINOR) - set(Eigen3_FIND_VERSION_MINOR 91) - endif(NOT Eigen3_FIND_VERSION_MINOR) - if(NOT Eigen3_FIND_VERSION_PATCH) - set(Eigen3_FIND_VERSION_PATCH 0) - endif(NOT Eigen3_FIND_VERSION_PATCH) - - set(Eigen3_FIND_VERSION "${Eigen3_FIND_VERSION_MAJOR}.${Eigen3_FIND_VERSION_MINOR}.${Eigen3_FIND_VERSION_PATCH}") -endif(NOT Eigen3_FIND_VERSION) - -macro(_eigen3_check_version) - file(READ "${EIGEN3_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h" _eigen3_version_header) - - string(REGEX MATCH "define[ \t]+EIGEN_WORLD_VERSION[ \t]+([0-9]+)" _eigen3_world_version_match "${_eigen3_version_header}") - set(EIGEN3_WORLD_VERSION "${CMAKE_MATCH_1}") - string(REGEX MATCH "define[ \t]+EIGEN_MAJOR_VERSION[ \t]+([0-9]+)" _eigen3_major_version_match "${_eigen3_version_header}") - set(EIGEN3_MAJOR_VERSION "${CMAKE_MATCH_1}") - string(REGEX MATCH "define[ \t]+EIGEN_MINOR_VERSION[ \t]+([0-9]+)" _eigen3_minor_version_match "${_eigen3_version_header}") - set(EIGEN3_MINOR_VERSION "${CMAKE_MATCH_1}") - - set(EIGEN3_VERSION ${EIGEN3_WORLD_VERSION}.${EIGEN3_MAJOR_VERSION}.${EIGEN3_MINOR_VERSION}) - if(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) - set(EIGEN3_VERSION_OK FALSE) - else(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) - set(EIGEN3_VERSION_OK TRUE) - endif(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) - - if(NOT EIGEN3_VERSION_OK) - - message(STATUS "Eigen3 version ${EIGEN3_VERSION} found in ${EIGEN3_INCLUDE_DIR}, " - "but at least version ${Eigen3_FIND_VERSION} is required") - endif(NOT EIGEN3_VERSION_OK) -endmacro(_eigen3_check_version) - -if (EIGEN3_INCLUDE_DIR) - - # in cache already - _eigen3_check_version() - set(EIGEN3_FOUND ${EIGEN3_VERSION_OK}) - -else (EIGEN3_INCLUDE_DIR) - - find_path(EIGEN3_INCLUDE_DIR NAMES signature_of_eigen3_matrix_library - PATHS - ${CMAKE_INSTALL_PREFIX}/include - ${KDE4_INCLUDE_DIR} - PATH_SUFFIXES eigen3 eigen - ) - - if(EIGEN3_INCLUDE_DIR) - _eigen3_check_version() - endif(EIGEN3_INCLUDE_DIR) - - include(FindPackageHandleStandardArgs) - find_package_handle_standard_args(Eigen3 DEFAULT_MSG EIGEN3_INCLUDE_DIR EIGEN3_VERSION_OK) - - mark_as_advanced(EIGEN3_INCLUDE_DIR) - -endif(EIGEN3_INCLUDE_DIR) - diff --git a/cmake/HandleEigen.cmake b/cmake/HandleEigen.cmake index c49eb4f8e..48941b85b 100644 --- a/cmake/HandleEigen.cmake +++ b/cmake/HandleEigen.cmake @@ -1,7 +1,7 @@ ############################################################################### # Option for using system Eigen or GTSAM-bundled Eigen # Default: Use system's Eigen if found automatically: -find_package(Eigen3 QUIET) +find_package(Eigen3 CONFIG QUIET) set(USE_SYSTEM_EIGEN_INITIAL_VALUE ${Eigen3_FOUND}) option(GTSAM_USE_SYSTEM_EIGEN "Find and use system-installed Eigen. If 'off', use the one bundled with GTSAM" ${USE_SYSTEM_EIGEN_INITIAL_VALUE}) unset(USE_SYSTEM_EIGEN_INITIAL_VALUE) @@ -14,10 +14,14 @@ endif() # Switch for using system Eigen or GTSAM-bundled Eigen if(GTSAM_USE_SYSTEM_EIGEN) - find_package(Eigen3 REQUIRED) # need to find again as REQUIRED + # Since Eigen 3.3.0 a Eigen3Config.cmake is available so use it. + find_package(Eigen3 CONFIG REQUIRED) # need to find again as REQUIRED - # Use generic Eigen include paths e.g. - set(GTSAM_EIGEN_INCLUDE_FOR_INSTALL "${EIGEN3_INCLUDE_DIR}") + # The actual include directory (for BUILD cmake target interface): + # Note: EIGEN3_INCLUDE_DIR points to some random location on some eigen + # versions. So here I use the target itself to get the proper include + # directory (it is generated by cmake, thus has the correct path) + get_target_property(GTSAM_EIGEN_INCLUDE_FOR_BUILD Eigen3::Eigen INTERFACE_INCLUDE_DIRECTORIES) # check if MKL is also enabled - can have one or the other, but not both! # Note: Eigen >= v3.2.5 includes our patches @@ -30,9 +34,6 @@ if(GTSAM_USE_SYSTEM_EIGEN) if(EIGEN_USE_MKL_ALL AND (EIGEN3_VERSION VERSION_EQUAL 3.3.4)) message(FATAL_ERROR "MKL does not work with Eigen 3.3.4 because of a bug in Eigen. See http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1527. Disable GTSAM_USE_SYSTEM_EIGEN to use GTSAM's copy of Eigen, disable GTSAM_WITH_EIGEN_MKL, or upgrade/patch your installation of Eigen.") endif() - - # The actual include directory (for BUILD cmake target interface): - set(GTSAM_EIGEN_INCLUDE_FOR_BUILD "${EIGEN3_INCLUDE_DIR}") else() # Use bundled Eigen include path. # Clear any variables set by FindEigen3 @@ -46,6 +47,19 @@ else() # The actual include directory (for BUILD cmake target interface): set(GTSAM_EIGEN_INCLUDE_FOR_BUILD "${GTSAM_SOURCE_DIR}/gtsam/3rdparty/Eigen/") + + add_library(gtsam_eigen3 INTERFACE) + + target_include_directories(gtsam_eigen3 INTERFACE + $ + $ + ) + add_library(Eigen3::Eigen ALIAS gtsam_eigen3) + + install(TARGETS gtsam_eigen3 EXPORT GTSAM-exports PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + + list(APPEND GTSAM_EXPORTED_TARGETS gtsam_eigen3) + set(GTSAM_EXPORTED_TARGETS "${GTSAM_EXPORTED_TARGETS}" PARENT_SCOPE) endif() # Detect Eigen version: diff --git a/gtsam/CMakeLists.txt b/gtsam/CMakeLists.txt index 09f1ea806..d3408ee7f 100644 --- a/gtsam/CMakeLists.txt +++ b/gtsam/CMakeLists.txt @@ -117,12 +117,9 @@ set_target_properties(gtsam PROPERTIES VERSION ${gtsam_version} SOVERSION ${gtsam_soversion}) -# Append Eigen include path, set in top-level CMakeLists.txt to either # system-eigen, or GTSAM eigen path -target_include_directories(gtsam PUBLIC - $ - $ -) +target_link_libraries(gtsam PUBLIC Eigen3::Eigen) + # MKL include dir: if (GTSAM_USE_EIGEN_MKL) target_include_directories(gtsam PUBLIC ${MKL_INCLUDE_DIR}) diff --git a/gtsam/base/treeTraversal-inst.h b/gtsam/base/treeTraversal-inst.h index 30cec3b9a..be45a248e 100644 --- a/gtsam/base/treeTraversal-inst.h +++ b/gtsam/base/treeTraversal-inst.h @@ -221,6 +221,6 @@ void PrintForest(const FOREST& forest, std::string str, PrintForestVisitorPre visitor(keyFormatter); DepthFirstForest(forest, str, visitor); } -} +} // namespace treeTraversal -} +} // namespace gtsam diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h index 674c625ce..0ea84e450 100644 --- a/gtsam/discrete/Assignment.h +++ b/gtsam/discrete/Assignment.h @@ -11,15 +11,17 @@ /** * @file Assignment.h - * @brief An assignment from labels to a discrete value index (size_t) + * @brief An assignment from labels to a discrete value index (size_t) * @author Frank Dellaert * @date Feb 5, 2012 */ #pragma once +#include #include #include +#include #include #include @@ -33,13 +35,30 @@ namespace gtsam { */ template class Assignment : public std::map { + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when + * printing. + * + * @param x The value passed to format. + * @return std::string + */ + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } + public: using std::map::operator=; - void print(const std::string& s = "Assignment: ") const { + void print(const std::string& s = "Assignment: ", + const std::function& labelFormatter = + &DefaultFormatter) const { std::cout << s << ": "; - for (const typename Assignment::value_type& keyValue : *this) - std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; + for (const typename Assignment::value_type& keyValue : *this) { + std::cout << "(" << labelFormatter(keyValue.first) << ", " + << keyValue.second << ")"; + } std::cout << std::endl; } diff --git a/gtsam/discrete/DiscreteKey.cpp b/gtsam/discrete/DiscreteKey.cpp index 121d61103..06ed2ca3b 100644 --- a/gtsam/discrete/DiscreteKey.cpp +++ b/gtsam/discrete/DiscreteKey.cpp @@ -48,4 +48,25 @@ namespace gtsam { return keys & key2; } + void DiscreteKeys::print(const std::string& s, + const KeyFormatter& keyFormatter) const { + for (auto&& dkey : *this) { + std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second + << std::endl; + } + } + + bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const { + if (this->size() != other.size()) { + return false; + } + + for (size_t i = 0; i < this->size(); i++) { + if (this->at(i).first != other.at(i).first || + this->at(i).second != other.at(i).second) { + return false; + } + } + return true; + } } diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index 40343d21f..fe348ee62 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -70,8 +71,30 @@ namespace gtsam { push_back(key); return *this; } + + /// Print the keys and cardinalities. + void print(const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// Check equality to another DiscreteKeys object. + bool equals(const DiscreteKeys& other, double tol = 0) const; + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& boost::serialization::make_nvp( + "DiscreteKeys", + boost::serialization::base_object>(*this)); + } + }; // DiscreteKeys /// Create a list from two keys GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); -} + + // traits + template <> + struct traits : public Testable {}; + + } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 6635633a2..ab69e82d7 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -159,6 +159,10 @@ TEST(DiscreteBayesTree, ThinTree) { clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); + DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9); + DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9); + DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9); + // check separator marginal P(S9), should be P(14) clique = (*self.bayesTree)[9]; DiscreteFactorGraph separatorMarginal9 = diff --git a/gtsam/discrete/tests/testDiscreteFactor.cpp b/gtsam/discrete/tests/testDiscreteFactor.cpp index 8681cf7eb..db0491c9d 100644 --- a/gtsam/discrete/tests/testDiscreteFactor.cpp +++ b/gtsam/discrete/tests/testDiscreteFactor.cpp @@ -16,14 +16,29 @@ * @author Duy-Nguyen Ta */ -#include -#include #include +#include +#include +#include + #include using namespace boost::assign; using namespace std; using namespace gtsam; +using namespace gtsam::serializationTestHelpers; + +/* ************************************************************************* */ +TEST(DisreteKeys, Serialization) { + DiscreteKeys keys; + keys& DiscreteKey(0, 2); + keys& DiscreteKey(1, 3); + keys& DiscreteKey(2, 4); + + EXPECT(equalsObj(keys)); + EXPECT(equalsXML(keys)); + EXPECT(equalsBinary(keys)); +} /* ************************************************************************* */ int main() { @@ -31,4 +46,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam/geometry/tests/testSimilarity2.cpp b/gtsam/geometry/tests/testSimilarity2.cpp index dd4fd0efd..ca041fc7b 100644 --- a/gtsam/geometry/tests/testSimilarity2.cpp +++ b/gtsam/geometry/tests/testSimilarity2.cpp @@ -33,8 +33,6 @@ static const Point2 P(0.2, 0.7); static const Rot2 R = Rot2::fromAngle(0.3); static const double s = 4; -const double degree = M_PI / 180; - //****************************************************************************** TEST(Similarity2, Concepts) { BOOST_CONCEPT_ASSERT((IsGroup)); diff --git a/gtsam/gtsam.i b/gtsam/gtsam.i index 2671f0ef7..00b4d05f8 100644 --- a/gtsam/gtsam.i +++ b/gtsam/gtsam.i @@ -66,6 +66,27 @@ class KeySet { void serialize() const; }; +// Actually a vector, needed for Matlab +class KeyVector { + KeyVector(); + KeyVector(const gtsam::KeyVector& other); + + // Note: no print function + + // common STL methods + size_t size() const; + bool empty() const; + void clear(); + + // structure specific methods + size_t at(size_t i) const; + size_t front() const; + size_t back() const; + void push_back(size_t key) const; + + void serialize() const; +}; + // Actually a FastMap class KeyGroupMap { KeyGroupMap(); diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 6816dfbf6..5172a9798 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -119,33 +119,90 @@ void GaussianMixture::print(const std::string &s, "", [&](Key k) { return formatter(k); }, [&](const GaussianConditional::shared_ptr &gf) -> std::string { RedirectCout rd; - if (gf && !gf->empty()) + if (gf && !gf->empty()) { gf->print("", formatter); - else - return {"nullptr"}; - return rd.str(); + return rd.str(); + } else { + return "nullptr"; + } }); } -/* *******************************************************************************/ -void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { - // Functional which loops over all assignments and create a set of - // GaussianConditionals - auto pruner = [&decisionTree]( +/* ************************************************************************* */ +/// Return the DiscreteKey vector as a set. +std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { + std::set s; + s.insert(dkeys.begin(), dkeys.end()); + return s; +} + +/* ************************************************************************* */ +/** + * @brief Helper function to get the pruner functional. + * + * @param decisionTree The probability decision tree of only discrete keys. + * @return std::function &, const GaussianConditional::shared_ptr &)> + */ +std::function &, const GaussianConditional::shared_ptr &)> +GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { + // Get the discrete keys as sets for the decision tree + // and the gaussian mixture. + auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); + auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys()); + + auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet]( const Assignment &choices, const GaussianConditional::shared_ptr &conditional) -> GaussianConditional::shared_ptr { // typecast so we can use this to get probability value DiscreteValues values(choices); - if (decisionTree(values) == 0.0) { - // empty aka null pointer - boost::shared_ptr null; - return null; + // Case where the gaussian mixture has the same + // discrete keys as the decision tree. + if (gaussianMixtureKeySet == decisionTreeKeySet) { + if (decisionTree(values) == 0.0) { + // empty aka null pointer + boost::shared_ptr null; + return null; + } else { + return conditional; + } } else { - return conditional; + std::vector set_diff; + std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), + gaussianMixtureKeySet.begin(), + gaussianMixtureKeySet.end(), + std::back_inserter(set_diff)); + + const std::vector assignments = + DiscreteValues::CartesianProduct(set_diff); + for (const DiscreteValues &assignment : assignments) { + DiscreteValues augmented_values(values); + augmented_values.insert(assignment.begin(), assignment.end()); + + // If any one of the sub-branches are non-zero, + // we need this conditional. + if (decisionTree(augmented_values) > 0.0) { + return conditional; + } + } + // If we are here, it means that all the sub-branches are 0, + // so we prune. + return nullptr; } }; + return pruner; +} + +/* *******************************************************************************/ +void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { + auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); + auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys()); + // Functional which loops over all assignments and create a set of + // GaussianConditionals + auto pruner = prunerFunc(decisionTree); auto pruned_conditionals = conditionals_.apply(pruner); conditionals_.root_ = pruned_conditionals.root_; diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 6d638ea74..d2276f35e 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -70,6 +70,17 @@ class GTSAM_EXPORT GaussianMixture */ Sum asGaussianFactorGraphTree() const; + /** + * @brief Helper function to get the pruner functor. + * + * @param decisionTree The pruned discrete probability decision tree. + * @return std::function &, const GaussianConditional::shared_ptr &)> + */ + std::function &, const GaussianConditional::shared_ptr &)> + prunerFunc(const DecisionTreeFactor &decisionTree); + public: /// @name Constructors /// @{ diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 9b5be188a..181b1e6a5 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -57,11 +57,12 @@ void GaussianMixtureFactor::print(const std::string &s, [&](const GaussianFactor::shared_ptr &gf) -> std::string { RedirectCout rd; std::cout << ":\n"; - if (gf) + if (gf && !gf->empty()) { gf->print("", formatter); - else - return {"nullptr"}; - return rd.str(); + return rd.str(); + } else { + return "nullptr"; + } }); std::cout << "}" << std::endl; } diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 4665a3136..cc27600f0 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -15,23 +15,40 @@ * @date January 2022 */ +#include +#include #include #include -#include namespace gtsam { /* ************************************************************************* */ -/// Return the DiscreteKey vector as a set. -static std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { - std::set s; - s.insert(dkeys.begin(), dkeys.end()); - return s; +DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { + AlgebraicDecisionTree decisionTree; + + // The canonical decision tree factor which will get the discrete conditionals + // added to it. + DecisionTreeFactor dtFactor; + + for (size_t i = 0; i < this->size(); i++) { + HybridConditional::shared_ptr conditional = this->at(i); + if (conditional->isDiscrete()) { + // Convert to a DecisionTreeFactor and add it to the main factor. + DecisionTreeFactor f(*conditional->asDiscreteConditional()); + dtFactor = dtFactor * f; + } + } + return boost::make_shared(dtFactor); } /* ************************************************************************* */ -HybridBayesNet HybridBayesNet::prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const { +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { + // Get the decision tree of only the discrete keys + auto discreteConditionals = this->discreteConditionals(); + const DecisionTreeFactor::shared_ptr discreteFactor = + boost::make_shared( + discreteConditionals->prune(maxNrLeaves)); + /* To Prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. @@ -41,61 +58,18 @@ HybridBayesNet HybridBayesNet::prune( HybridBayesNet prunedBayesNetFragment; - // Functional which loops over all assignments and create a set of - // GaussianConditionals - auto pruner = [&](const Assignment &choices, - const GaussianConditional::shared_ptr &conditional) - -> GaussianConditional::shared_ptr { - // typecast so we can use this to get probability value - DiscreteValues values(choices); - - if ((*discreteFactor)(values) == 0.0) { - // empty aka null pointer - boost::shared_ptr null; - return null; - } else { - return conditional; - } - }; - // Go through all the conditionals in the // Bayes Net and prune them as per discreteFactor. for (size_t i = 0; i < this->size(); i++) { HybridConditional::shared_ptr conditional = this->at(i); - GaussianMixture::shared_ptr gaussianMixture = - boost::dynamic_pointer_cast(conditional->inner()); + if (conditional->isHybrid()) { + GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); - if (gaussianMixture) { - // We may have mixtures with less discrete keys than discreteFactor so we - // skip those since the label assignment does not exist. - auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys()); - auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys()); - if (gmKeySet != dfKeySet) { - // Add the gaussianMixture which doesn't have to be pruned. - prunedBayesNetFragment.push_back( - boost::make_shared(gaussianMixture)); - continue; - } - - // Run the pruning to get a new, pruned tree - GaussianMixture::Conditionals prunedTree = - gaussianMixture->conditionals().apply(pruner); - - DiscreteKeys discreteKeys = gaussianMixture->discreteKeys(); - // reverse keys to get a natural ordering - std::reverse(discreteKeys.begin(), discreteKeys.end()); - - // Convert from boost::iterator_range to KeyVector - // so we can pass it to constructor. - KeyVector frontals(gaussianMixture->frontals().begin(), - gaussianMixture->frontals().end()), - parents(gaussianMixture->parents().begin(), - gaussianMixture->parents().end()); - - // Create the new gaussian mixture and add it to the bayes net. - auto prunedGaussianMixture = boost::make_shared( - frontals, parents, discreteKeys, prunedTree); + // Make a copy of the gaussian mixture and prune it! + auto prunedGaussianMixture = + boost::make_shared(*gaussianMixture); + prunedGaussianMixture->prune(*discreteFactor); // Type-erase and add to the pruned Bayes Net fragment. prunedBayesNetFragment.push_back( @@ -111,14 +85,18 @@ HybridBayesNet HybridBayesNet::prune( } /* ************************************************************************* */ -GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const { - return boost::dynamic_pointer_cast(factors_.at(i)->inner()); +GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const { + return factors_.at(i)->asMixture(); +} + +/* ************************************************************************* */ +GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const { + return factors_.at(i)->asGaussian(); } /* ************************************************************************* */ DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { - return boost::dynamic_pointer_cast( - factors_.at(i)->inner()); + return factors_.at(i)->asDiscreteConditional(); } /* ************************************************************************* */ @@ -126,16 +104,45 @@ GaussianBayesNet HybridBayesNet::choose( const DiscreteValues &assignment) const { GaussianBayesNet gbn; for (size_t idx = 0; idx < size(); idx++) { - GaussianMixture gm = *this->atGaussian(idx); - gbn.push_back(gm(assignment)); + if (factors_.at(idx)->isHybrid()) { + // If factor is hybrid, select based on assignment. + GaussianMixture gm = *this->atMixture(idx); + gbn.push_back(gm(assignment)); + + } else if (factors_.at(idx)->isContinuous()) { + // If continuous only, add gaussian conditional. + gbn.push_back((this->atGaussian(idx))); + + } else if (factors_.at(idx)->isDiscrete()) { + // If factor at `idx` is discrete-only, we simply continue. + continue; + } } + return gbn; } -/* *******************************************************************************/ +/* ************************************************************************* */ HybridValues HybridBayesNet::optimize() const { - auto dag = HybridLookupDAG::FromBayesNet(*this); - return dag.argmax(); + // Solve for the MPE + DiscreteBayesNet discrete_bn; + for (auto &conditional : factors_) { + if (conditional->isDiscrete()) { + discrete_bn.push_back(conditional->asDiscreteConditional()); + } + } + + DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize(); + + // Given the MPE, compute the optimal continuous values. + GaussianBayesNet gbn = this->choose(mpe); + return HybridValues(mpe, gbn.optimize()); +} + +/* ************************************************************************* */ +VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { + GaussianBayesNet gbn = this->choose(assignment); + return gbn.optimize(); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 152128122..dea7108fe 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include #include @@ -39,12 +40,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using shared_ptr = boost::shared_ptr; using sharedConditional = boost::shared_ptr; + /// @name Standard Constructors + /// @{ + /** Construct empty bayes net */ HybridBayesNet() = default; - /// Prune the Hybrid Bayes Net given the discrete decision tree. - HybridBayesNet prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const; + /// @} + /// @name Testable + /// @{ + + /** Check equality */ + bool equals(const This &bn, double tol = 1e-9) const { + return Base::equals(bn, tol); + } + + /// print graph + void print( + const std::string &s = "", + const KeyFormatter &formatter = DefaultKeyFormatter) const override { + Base::print(s, formatter); + } + + /// @} + /// @name Standard Interface + /// @{ /// Add HybridConditional to Bayes Net using Base::add; @@ -55,8 +75,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { HybridConditional(boost::make_shared(key, table))); } + using Base::push_back; + /// Get a specific Gaussian mixture by index `i`. - GaussianMixture::shared_ptr atGaussian(size_t i) const; + GaussianMixture::shared_ptr atMixture(size_t i) const; + + /// Get a specific Gaussian conditional by index `i`. + GaussianConditional::shared_ptr atGaussian(size_t i) const; /// Get a specific discrete conditional by index `i`. DiscreteConditional::shared_ptr atDiscrete(size_t i) const; @@ -70,10 +95,49 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ GaussianBayesNet choose(const DiscreteValues &assignment) const; - /// Solve the HybridBayesNet by back-substitution. - /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and - /// put this method there? + /** + * @brief Solve the HybridBayesNet by first computing the MPE of all the + * discrete variables and then optimizing the continuous variables based on + * the MPE assignment. + * + * @return HybridValues + */ HybridValues optimize() const; + + /** + * @brief Given the discrete assignment, return the optimized estimate for the + * selected Gaussian BayesNet. + * + * @param assignment An assignment of discrete values. + * @return Values + */ + VectorValues optimize(const DiscreteValues &assignment) const; + + protected: + /** + * @brief Get all the discrete conditionals as a decision tree factor. + * + * @return DecisionTreeFactor::shared_ptr + */ + DecisionTreeFactor::shared_ptr discreteConditionals() const; + + public: + /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. + HybridBayesNet prune(size_t maxNrLeaves) const; + + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } }; +/// traits +template <> +struct traits : public Testable {}; + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index d65270f91..266b295dd 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -18,10 +18,13 @@ */ #include +#include +#include #include #include #include #include +#include namespace gtsam { @@ -35,4 +38,161 @@ bool HybridBayesTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } +/* ************************************************************************* */ +HybridValues HybridBayesTree::optimize() const { + DiscreteBayesNet dbn; + DiscreteValues mpe; + + auto root = roots_.at(0); + // Access the clique and get the underlying hybrid conditional + HybridConditional::shared_ptr root_conditional = root->conditional(); + + // The root should be discrete only, we compute the MPE + if (root_conditional->isDiscrete()) { + dbn.push_back(root_conditional->asDiscreteConditional()); + mpe = DiscreteFactorGraph(dbn).optimize(); + } else { + throw std::runtime_error( + "HybridBayesTree root is not discrete-only. Please check elimination " + "ordering or use continuous factor graph."); + } + + VectorValues values = optimize(mpe); + return HybridValues(mpe, values); +} + +/* ************************************************************************* */ +/** + * @brief Helper class for Depth First Forest traversal on the HybridBayesTree. + * + * When traversing the tree, the pre-order visitor will receive an instance of + * this class with the parent clique data. + */ +struct HybridAssignmentData { + const DiscreteValues assignment_; + GaussianBayesTree::sharedNode parentClique_; + // The gaussian bayes tree that will be recursively created. + GaussianBayesTree* gaussianbayesTree_; + + /** + * @brief Construct a new Hybrid Assignment Data object. + * + * @param assignment The MPE assignment for the optimal Gaussian cliques. + * @param parentClique The clique from the parent node of the current node. + * @param gbt The Gaussian Bayes Tree being generated during tree traversal. + */ + HybridAssignmentData(const DiscreteValues& assignment, + const GaussianBayesTree::sharedNode& parentClique, + GaussianBayesTree* gbt) + : assignment_(assignment), + parentClique_(parentClique), + gaussianbayesTree_(gbt) {} + + /** + * @brief A function used during tree traversal that operates on each node + * before visiting the node's children. + * + * @param node The current node being visited. + * @param parentData The HybridAssignmentData from the parent node. + * @return HybridAssignmentData which is passed to the children. + */ + static HybridAssignmentData AssignmentPreOrderVisitor( + const HybridBayesTree::sharedNode& node, + HybridAssignmentData& parentData) { + // Extract the gaussian conditional from the Hybrid clique + HybridConditional::shared_ptr hybrid_conditional = node->conditional(); + GaussianConditional::shared_ptr conditional; + if (hybrid_conditional->isHybrid()) { + conditional = (*hybrid_conditional->asMixture())(parentData.assignment_); + } else if (hybrid_conditional->isContinuous()) { + conditional = hybrid_conditional->asGaussian(); + } else { + // Discrete only conditional, so we set to empty gaussian conditional + conditional = boost::make_shared(); + } + + // Create the GaussianClique for the current node + auto clique = boost::make_shared(conditional); + // Add the current clique to the GaussianBayesTree. + parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_); + + // Create new HybridAssignmentData where the current node is the parent + // This will be passed down to the children nodes + HybridAssignmentData data(parentData.assignment_, clique, + parentData.gaussianbayesTree_); + return data; + } +}; + +/* ************************************************************************* + */ +VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { + GaussianBayesTree gbt; + HybridAssignmentData rootData(assignment, 0, &gbt); + { + treeTraversal::no_op visitorPost; + // Limits OpenMP threads since we're mixing TBB and OpenMP + TbbOpenMPMixedScope threadLimiter; + treeTraversal::DepthFirstForestParallel( + *this, rootData, HybridAssignmentData::AssignmentPreOrderVisitor, + visitorPost); + } + + VectorValues result = gbt.optimize(); + + // Return the optimized bayes net result. + return result; +} + +/* ************************************************************************* */ +void HybridBayesTree::prune(const size_t maxNrLeaves) { + auto decisionTree = boost::dynamic_pointer_cast( + this->roots_.at(0)->conditional()->inner()); + + DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); + decisionTree->root_ = prunedDecisionTree.root_; + + /// Helper struct for pruning the hybrid bayes tree. + struct HybridPrunerData { + /// The discrete decision tree after pruning. + DecisionTreeFactor prunedDecisionTree; + HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree, + const HybridBayesTree::sharedNode& parentClique) + : prunedDecisionTree(prunedDecisionTree) {} + + /** + * @brief A function used during tree traversal that operates on each node + * before visiting the node's children. + * + * @param node The current node being visited. + * @param parentData The data from the parent node. + * @return HybridPrunerData which is passed to the children. + */ + static HybridPrunerData AssignmentPreOrderVisitor( + const HybridBayesTree::sharedNode& clique, + HybridPrunerData& parentData) { + // Get the conditional + HybridConditional::shared_ptr conditional = clique->conditional(); + + // If conditional is hybrid, we prune it. + if (conditional->isHybrid()) { + auto gaussianMixture = conditional->asMixture(); + + gaussianMixture->prune(parentData.prunedDecisionTree); + } + return parentData; + } + }; + + HybridPrunerData rootData(prunedDecisionTree, 0); + { + treeTraversal::no_op visitorPost; + // Limits OpenMP threads since we're mixing TBB and OpenMP + TbbOpenMPMixedScope threadLimiter; + treeTraversal::DepthFirstForestParallel( + *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, + visitorPost); + } +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 165f20a9f..8af0af968 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -73,9 +73,46 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { /** Check equality */ bool equals(const This& other, double tol = 1e-9) const; + /** + * @brief Optimize the hybrid Bayes tree by computing the MPE for the current + * set of discrete variables and using it to compute the best continuous + * update delta. + * + * @return HybridValues + */ + HybridValues optimize() const; + + /** + * @brief Recursively optimize the BayesTree to produce a vector solution. + * + * @param assignment The discrete values assignment to select the Gaussian + * mixtures. + * @return VectorValues + */ + VectorValues optimize(const DiscreteValues& assignment) const; + + /** + * @brief Prune the underlying Bayes tree. + * + * @param maxNumberLeaves The max number of leaf nodes to keep. + */ + void prune(const size_t maxNumberLeaves); + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } }; +/// traits +template <> +struct traits : public Testable {}; + /** * @brief Class for Hybrid Bayes tree orphan subtrees. * diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 0b82ccb2f..050f10290 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -34,8 +34,6 @@ namespace gtsam { -class HybridGaussianFactorGraph; - /** * Hybrid Conditional Density * @@ -71,7 +69,7 @@ class GTSAM_EXPORT HybridConditional BaseConditional; ///< Typedef to our conditional base class protected: - // Type-erased pointer to the inner type + /// Type-erased pointer to the inner type boost::shared_ptr inner_; public: @@ -129,8 +127,7 @@ class GTSAM_EXPORT HybridConditional * @param gaussianMixture Gaussian Mixture Conditional used to create the * HybridConditional. */ - HybridConditional( - boost::shared_ptr gaussianMixture); + HybridConditional(boost::shared_ptr gaussianMixture); /** * @brief Return HybridConditional as a GaussianMixture @@ -142,6 +139,17 @@ class GTSAM_EXPORT HybridConditional return boost::static_pointer_cast(inner_); } + /** + * @brief Return HybridConditional as a GaussianConditional + * + * @return GaussianConditional::shared_ptr + */ + GaussianConditional::shared_ptr asGaussian() { + if (!isContinuous()) + throw std::invalid_argument("Not a continuous conditional"); + return boost::static_pointer_cast(inner_); + } + /** * @brief Return conditional as a DiscreteConditional * @@ -170,10 +178,19 @@ class GTSAM_EXPORT HybridConditional /// Get the type-erased pointer to the inner type boost::shared_ptr inner() { return inner_; } -}; // DiscreteConditional + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + } + +}; // HybridConditional // traits template <> -struct traits : public Testable {}; +struct traits : public Testable {}; } // namespace gtsam diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index a9fe62cf1..1216fd922 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -50,10 +50,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &keys) - : Base(keys), - isContinuous_(true), - nrContinuous_(keys.size()), - continuousKeys_(keys) {} + : Base(keys), isContinuous_(true), continuousKeys_(keys) {} /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &continuousKeys, @@ -62,7 +59,6 @@ HybridFactor::HybridFactor(const KeyVector &continuousKeys, isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), - nrContinuous_(continuousKeys.size()), discreteKeys_(discreteKeys), continuousKeys_(continuousKeys) {} @@ -103,7 +99,6 @@ void HybridFactor::print(const std::string &s, if (d < discreteKeys_.size() - 1) { std::cout << " "; } - } std::cout << "]"; } diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 138955f15..e0cae55c1 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -49,8 +49,6 @@ class GTSAM_EXPORT HybridFactor : public Factor { bool isContinuous_ = false; bool isHybrid_ = false; - size_t nrContinuous_ = 0; - protected: // Set of DiscreteKeys for this factor. DiscreteKeys discreteKeys_; @@ -131,6 +129,19 @@ class GTSAM_EXPORT HybridFactor : public Factor { const KeyVector &continuousKeys() const { return continuousKeys_; } /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(isDiscrete_); + ar &BOOST_SERIALIZATION_NVP(isContinuous_); + ar &BOOST_SERIALIZATION_NVP(isHybrid_); + ar &BOOST_SERIALIZATION_NVP(discreteKeys_); + ar &BOOST_SERIALIZATION_NVP(continuousKeys_); + } }; // HybridFactor diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index fc730f0c9..05a17b000 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph { push_hybrid(p); } } + + /// Get all the discrete keys in the factor graph. + const KeySet discreteKeys() const { + KeySet discrete_keys; + for (auto& factor : factors_) { + for (const DiscreteKey& k : factor->discreteKeys()) { + discrete_keys.insert(k.first); + } + } + return discrete_keys; + } + + /// Get all the continuous keys in the factor graph. + const KeySet continuousKeys() const { + KeySet keys; + for (auto& factor : factors_) { + for (const Key& key : factor->continuousKeys()) { + keys.insert(key); + } + } + return keys; + } }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c024c1255..041603fbd 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -96,8 +96,12 @@ GaussianMixtureFactor::Sum sumFrontals( } } else if (f->isContinuous()) { - deferredFactors.push_back( - boost::dynamic_pointer_cast(f)->inner()); + if (auto gf = boost::dynamic_pointer_cast(f)) { + deferredFactors.push_back(gf->inner()); + } + if (auto cg = boost::dynamic_pointer_cast(f)) { + deferredFactors.push_back(cg->asGaussian()); + } } else if (f->isDiscrete()) { // Don't do anything for discrete-only factors @@ -135,9 +139,9 @@ continuousElimination(const HybridGaussianFactorGraph &factors, for (auto &fp : factors) { if (auto ptr = boost::dynamic_pointer_cast(fp)) { gfg.push_back(ptr->inner()); - } else if (auto p = - boost::static_pointer_cast(fp)->inner()) { - gfg.push_back(boost::static_pointer_cast(p)); + } else if (auto ptr = boost::static_pointer_cast(fp)) { + gfg.push_back( + boost::static_pointer_cast(ptr->inner())); } else { // It is an orphan wrapped conditional } @@ -153,12 +157,14 @@ std::pair discreteElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { DiscreteFactorGraph dfg; - for (auto &fp : factors) { - if (auto ptr = boost::dynamic_pointer_cast(fp)) { - dfg.push_back(ptr->inner()); - } else if (auto p = - boost::static_pointer_cast(fp)->inner()) { - dfg.push_back(boost::static_pointer_cast(p)); + + for (auto &factor : factors) { + if (auto p = boost::dynamic_pointer_cast(factor)) { + dfg.push_back(p->inner()); + } else if (auto p = boost::static_pointer_cast(factor)) { + auto discrete_conditional = + boost::static_pointer_cast(p->inner()); + dfg.push_back(discrete_conditional); } else { // It is an orphan wrapper } @@ -213,10 +219,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors, result = EliminatePreferCholesky(graph, frontalKeys); if (keysOfEliminated.empty()) { - keysOfEliminated = - result.first->keys(); // Initialize the keysOfEliminated to be the + // Initialize the keysOfEliminated to be the keys of the + // eliminated GaussianConditional + keysOfEliminated = result.first->keys(); } - // keysOfEliminated of the GaussianConditional if (keysOfSeparator.empty()) { keysOfSeparator = result.second->keys(); } @@ -244,6 +250,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, return exp(-factor->error(empty_values)); }; DecisionTree fdt(separatorFactors, factorError); + auto discreteFactor = boost::make_shared(discreteSeparator, fdt); @@ -401,4 +408,19 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { FactorGraph::add(boost::make_shared(factor)); } +/* ************************************************************************ */ +const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { + KeySet discrete_keys = discreteKeys(); + for (auto &factor : factors_) { + for (const DiscreteKey &k : factor->discreteKeys()) { + discrete_keys.insert(k.first); + } + } + + const VariableIndex index(factors_); + Ordering ordering = Ordering::ColamdConstrainedLast( + index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); + return ordering; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index f12c93e8d..6a0362500 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -169,6 +169,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph Base::push_back(sharedFactor); } } + + /** + * @brief Return a Colamd constrained ordering where the discrete keys are + * eliminated after the continuous keys. + * + * @return const Ordering + */ + const Ordering getHybridOrdering() const; }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 23a95c021..de87dd92f 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -14,9 +14,10 @@ * @date March 31, 2022 * @author Fan Jiang * @author Frank Dellaert - * @author Richard Roberts + * @author Varun Agrawal */ +#include #include #include #include @@ -41,6 +42,7 @@ HybridGaussianISAM::HybridGaussianISAM(const HybridBayesTree& bayesTree) void HybridGaussianISAM::updateInternal( const HybridGaussianFactorGraph& newFactors, HybridBayesTree::Cliques* orphans, + const boost::optional& maxNrLeaves, const boost::optional& ordering, const HybridBayesTree::Eliminate& function) { // Remove the contaminated part of the Bayes tree @@ -57,26 +59,28 @@ void HybridGaussianISAM::updateInternal( factors += newFactors; // Add the orphaned subtrees - for (const sharedClique& orphan : *orphans) - factors += boost::make_shared >(orphan); - - KeySet allDiscrete; - for (auto& factor : factors) { - for (auto& k : factor->discreteKeys()) { - allDiscrete.insert(k.first); - } + for (const sharedClique& orphan : *orphans) { + factors += boost::make_shared>(orphan); } + + // Get all the discrete keys from the factors + KeySet allDiscrete = factors.discreteKeys(); + + // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; + // Insert continuous keys first. for (auto& k : newFactorKeys) { if (!allDiscrete.exists(k)) { newKeysDiscreteLast.push_back(k); } } + // Insert discrete keys at the end std::copy(allDiscrete.begin(), allDiscrete.end(), std::back_inserter(newKeysDiscreteLast)); // Get an ordering where the new keys are eliminated last const VariableIndex index(factors); + Ordering elimination_ordering; if (ordering) { elimination_ordering = *ordering; @@ -91,6 +95,10 @@ void HybridGaussianISAM::updateInternal( HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal(elimination_ordering, function, index); + if (maxNrLeaves) { + bayesTree->prune(*maxNrLeaves); + } + // Re-add into Bayes tree data structures this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(), bayesTree->roots().end()); @@ -99,61 +107,11 @@ void HybridGaussianISAM::updateInternal( /* ************************************************************************* */ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors, + const boost::optional& maxNrLeaves, const boost::optional& ordering, const HybridBayesTree::Eliminate& function) { Cliques orphans; - this->updateInternal(newFactors, &orphans, ordering, function); -} - -/* ************************************************************************* */ -/** - * @brief Check if `b` is a subset of `a`. - * Non-const since they need to be sorted. - * - * @param a KeyVector - * @param b KeyVector - * @return True if the keys of b is a subset of a, else false. - */ -bool IsSubset(KeyVector a, KeyVector b) { - std::sort(a.begin(), a.end()); - std::sort(b.begin(), b.end()); - return std::includes(a.begin(), a.end(), b.begin(), b.end()); -} - -/* ************************************************************************* */ -void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) { - auto decisionTree = boost::dynamic_pointer_cast( - this->clique(root)->conditional()->inner()); - DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); - decisionTree->root_ = prunedDiscreteFactor.root_; - - std::vector prunedKeys; - for (auto&& clique : nodes()) { - // The cliques can be repeated for each frontal so we record it in - // prunedKeys and check if we have already pruned a particular clique. - if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) != - prunedKeys.end()) { - continue; - } - - // Add all the keys of the current clique to be pruned to prunedKeys - for (auto&& key : clique.second->conditional()->frontals()) { - prunedKeys.push_back(key); - } - - // Convert parents() to a KeyVector for comparison - KeyVector parents; - for (auto&& parent : clique.second->conditional()->parents()) { - parents.push_back(parent); - } - - if (IsSubset(parents, decisionTree->keys())) { - auto gaussianMixture = boost::dynamic_pointer_cast( - clique.second->conditional()->inner()); - - gaussianMixture->prune(prunedDiscreteFactor); - } - } + this->updateInternal(newFactors, &orphans, maxNrLeaves, ordering, function); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianISAM.h b/gtsam/hybrid/HybridGaussianISAM.h index d01d04862..35cbd6ecd 100644 --- a/gtsam/hybrid/HybridGaussianISAM.h +++ b/gtsam/hybrid/HybridGaussianISAM.h @@ -53,6 +53,7 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM { void updateInternal( const HybridGaussianFactorGraph& newFactors, HybridBayesTree::Cliques* orphans, + const boost::optional& maxNrLeaves = boost::none, const boost::optional& ordering = boost::none, const HybridBayesTree::Eliminate& function = HybridBayesTree::EliminationTraitsType::DefaultEliminate); @@ -62,20 +63,15 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM { * @brief Perform update step with new factors. * * @param newFactors Factor graph of new factors to add and eliminate. + * @param maxNrLeaves The maximum number of leaves to keep after pruning. + * @param ordering Custom elimination ordering. * @param function Elimination function. */ void update(const HybridGaussianFactorGraph& newFactors, + const boost::optional& maxNrLeaves = boost::none, const boost::optional& ordering = boost::none, const HybridBayesTree::Eliminate& function = HybridBayesTree::EliminationTraitsType::DefaultEliminate); - - /** - * @brief - * - * @param root The root key in the discrete conditional decision tree. - * @param maxNumberLeaves - */ - void prune(const Key& root, const size_t maxNumberLeaves); }; /// traits diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 7725742cf..422c200a4 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -31,9 +31,7 @@ template class EliminatableClusterTree; struct HybridConstructorTraversalData { - typedef - typename JunctionTree::Node - Node; + typedef HybridJunctionTree::Node Node; typedef typename JunctionTree::sharedNode sharedNode; @@ -62,6 +60,7 @@ struct HybridConstructorTraversalData { data.junctionTreeNode = boost::make_shared(node->key, node->factors); parentData.junctionTreeNode->addChild(data.junctionTreeNode); + // Add all the discrete keys in the hybrid factors to the current data for (HybridFactor::shared_ptr& f : node->factors) { for (auto& k : f->discreteKeys()) { data.discreteKeys.insert(k.first); @@ -72,8 +71,8 @@ struct HybridConstructorTraversalData { } // Post-order visitor function - static void ConstructorTraversalVisitorPostAlg2( - const boost::shared_ptr& ETreeNode, + static void ConstructorTraversalVisitorPost( + const boost::shared_ptr& node, const HybridConstructorTraversalData& data) { // In this post-order visitor, we combine the symbolic elimination results // from the elimination tree children and symbolically eliminate the current @@ -86,15 +85,15 @@ struct HybridConstructorTraversalData { // Do symbolic elimination for this node SymbolicFactors symbolicFactors; - symbolicFactors.reserve(ETreeNode->factors.size() + + symbolicFactors.reserve(node->factors.size() + data.childSymbolicFactors.size()); // Add ETree node factors - symbolicFactors += ETreeNode->factors; + symbolicFactors += node->factors; // Add symbolic factors passed up from children symbolicFactors += data.childSymbolicFactors; Ordering keyAsOrdering; - keyAsOrdering.push_back(ETreeNode->key); + keyAsOrdering.push_back(node->key); SymbolicConditional::shared_ptr conditional; SymbolicFactor::shared_ptr separatorFactor; boost::tie(conditional, separatorFactor) = @@ -105,19 +104,19 @@ struct HybridConstructorTraversalData { data.parentData->childSymbolicFactors.push_back(separatorFactor); data.parentData->discreteKeys.merge(data.discreteKeys); - sharedNode node = data.junctionTreeNode; + sharedNode jt_node = data.junctionTreeNode; const FastVector& childConditionals = data.childSymbolicConditionals; - node->problemSize_ = (int)(conditional->size() * symbolicFactors.size()); + jt_node->problemSize_ = (int)(conditional->size() * symbolicFactors.size()); // Merge our children if they are in our clique - if our conditional has // exactly one fewer parent than our child's conditional. const size_t nrParents = conditional->nrParents(); - const size_t nrChildren = node->nrChildren(); + const size_t nrChildren = jt_node->nrChildren(); assert(childConditionals.size() == nrChildren); // decide which children to merge, as index into children - std::vector nrChildrenFrontals = node->nrFrontalsOfChildren(); + std::vector nrChildrenFrontals = jt_node->nrFrontalsOfChildren(); std::vector merge(nrChildren, false); size_t nrFrontals = 1; for (size_t i = 0; i < nrChildren; i++) { @@ -137,7 +136,7 @@ struct HybridConstructorTraversalData { } // now really merge - node->mergeChildren(merge); + jt_node->mergeChildren(merge); } }; @@ -161,7 +160,7 @@ HybridJunctionTree::HybridJunctionTree( // the junction tree roots treeTraversal::DepthFirstForest(eliminationTree, rootData, Data::ConstructorTraversalVisitorPre, - Data::ConstructorTraversalVisitorPostAlg2); + Data::ConstructorTraversalVisitorPost); // Assign roots from the dummy node this->addChildrenAsRoots(rootData.junctionTreeNode); diff --git a/gtsam/hybrid/HybridLookupDAG.cpp b/gtsam/hybrid/HybridLookupDAG.cpp deleted file mode 100644 index a322a8177..000000000 --- a/gtsam/hybrid/HybridLookupDAG.cpp +++ /dev/null @@ -1,76 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file DiscreteLookupDAG.cpp - * @date Aug, 2022 - * @author Shangjie Xue - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -using std::pair; -using std::vector; - -namespace gtsam { - -/* ************************************************************************** */ -void HybridLookupTable::argmaxInPlace(HybridValues* values) const { - // For discrete conditional, uses argmaxInPlace() method in - // DiscreteLookupTable. - if (isDiscrete()) { - boost::static_pointer_cast(inner_)->argmaxInPlace( - &(values->discrete)); - } else if (isContinuous()) { - // For Gaussian conditional, uses solve() method in GaussianConditional. - values->continuous.insert( - boost::static_pointer_cast(inner_)->solve( - values->continuous)); - } else if (isHybrid()) { - // For hybrid conditional, since children should not contain discrete - // variable, we can condition on the discrete variable in the parents and - // solve the resulting GaussianConditional. - auto conditional = - boost::static_pointer_cast(inner_)->conditionals()( - values->discrete); - values->continuous.insert(conditional->solve(values->continuous)); - } -} - -/* ************************************************************************** */ -HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { - HybridLookupDAG dag; - for (auto&& conditional : bayesNet) { - HybridLookupTable hlt(*conditional); - dag.push_back(hlt); - } - return dag; -} - -/* ************************************************************************** */ -HybridValues HybridLookupDAG::argmax(HybridValues result) const { - // Argmax each node in turn in topological sort order (parents first). - for (auto lookupTable : boost::adaptors::reverse(*this)) - lookupTable->argmaxInPlace(&result); - return result; -} - -} // namespace gtsam diff --git a/gtsam/hybrid/HybridLookupDAG.h b/gtsam/hybrid/HybridLookupDAG.h deleted file mode 100644 index cc1c58c58..000000000 --- a/gtsam/hybrid/HybridLookupDAG.h +++ /dev/null @@ -1,119 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file HybridLookupDAG.h - * @date Aug, 2022 - * @author Shangjie Xue - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace gtsam { - -/** - * @brief HybridLookupTable table for max-product - * - * Similar to DiscreteLookupTable, inherits from hybrid conditional for - * convenience. Is used in the max-product algorithm. - */ -class GTSAM_EXPORT HybridLookupTable : public HybridConditional { - public: - using Base = HybridConditional; - using This = HybridLookupTable; - using shared_ptr = boost::shared_ptr; - using BaseConditional = Conditional; - - /** - * @brief Construct a new Hybrid Lookup Table object form a HybridConditional. - * - * @param conditional input hybrid conditional - */ - HybridLookupTable(HybridConditional& conditional) : Base(conditional){}; - - /** - * @brief Calculate assignment for frontal variables that maximizes value. - * @param (in/out) parentsValues Known assignments for the parents. - */ - void argmaxInPlace(HybridValues* parentsValues) const; -}; - -/** A DAG made from hybrid lookup tables, as defined above. Similar to - * DiscreteLookupDAG */ -class GTSAM_EXPORT HybridLookupDAG : public BayesNet { - public: - using Base = BayesNet; - using This = HybridLookupDAG; - using shared_ptr = boost::shared_ptr; - - /// @name Standard Constructors - /// @{ - - /// Construct empty DAG. - HybridLookupDAG() {} - - /// Create from BayesNet with LookupTables - static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet); - - /// Destructor - virtual ~HybridLookupDAG() {} - - /// @} - - /// @name Standard Interface - /// @{ - - /** Add a DiscreteLookupTable */ - template - void add(Args&&... args) { - emplace_shared(std::forward(args)...); - } - - /** - * @brief argmax by back-substitution, optionally given certain variables. - * - * Assumes the DAG is reverse topologically sorted, i.e. last - * conditional will be optimized first *and* that the - * DAG does not contain any conditionals for the given variables. If the DAG - * resulted from eliminating a factor graph, this is true for the elimination - * ordering. - * - * @return given assignment extended w. optimal assignment for all variables. - */ - HybridValues argmax(HybridValues given = HybridValues()) const; - /// @} - - private: - /** Serialization function */ - friend class boost::serialization::access; - template - void serialize(ARCHIVE& ar, const unsigned int /*version*/) { - ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); - } -}; - -// traits -template <> -struct traits : public Testable {}; - -} // namespace gtsam diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index a4218593b..3a3bf720b 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -27,8 +27,7 @@ void HybridNonlinearFactorGraph::add( } /* ************************************************************************* */ -void HybridNonlinearFactorGraph::add( - boost::shared_ptr factor) { +void HybridNonlinearFactorGraph::add(boost::shared_ptr factor) { FactorGraph::add(boost::make_shared(factor)); } @@ -49,12 +48,12 @@ void HybridNonlinearFactorGraph::print(const std::string& s, } /* ************************************************************************* */ -HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize( +HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( const Values& continuousValues) const { // create an empty linear FG - HybridGaussianFactorGraph linearFG; + auto linearFG = boost::make_shared(); - linearFG.reserve(size()); + linearFG->reserve(size()); // linearize all hybrid factors for (auto&& factor : factors_) { @@ -66,9 +65,9 @@ HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize( if (factor->isHybrid()) { // Check if it is a nonlinear mixture factor if (auto nlmf = boost::dynamic_pointer_cast(factor)) { - linearFG.push_back(nlmf->linearize(continuousValues)); + linearFG->push_back(nlmf->linearize(continuousValues)); } else { - linearFG.push_back(factor); + linearFG->push_back(factor); } // Now check if the factor is a continuous only factor. @@ -80,18 +79,18 @@ HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize( boost::dynamic_pointer_cast(nlhf->inner())) { auto hgf = boost::make_shared( nlf->linearize(continuousValues)); - linearFG.push_back(hgf); + linearFG->push_back(hgf); } else { - linearFG.push_back(factor); + linearFG->push_back(factor); } // Finally if nothing else, we are discrete-only which doesn't need // lineariztion. } else { - linearFG.push_back(factor); + linearFG->push_back(factor); } } else { - linearFG.push_back(GaussianFactor::shared_ptr()); + linearFG->push_back(GaussianFactor::shared_ptr()); } } return linearFG; diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index 7a19c7755..b48e8bb5c 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -42,6 +42,16 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { using IsNonlinear = typename std::enable_if< std::is_base_of::value>::type; + /// Check if T has a value_type derived from FactorType. + template + using HasDerivedValueType = typename std::enable_if< + std::is_base_of::value>::type; + + /// Check if T has a pointer type derived from FactorType. + template + using HasDerivedElementType = typename std::enable_if::value>::type; + public: using Base = HybridFactorGraph; using This = HybridNonlinearFactorGraph; ///< this class @@ -109,6 +119,21 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { } } + /** + * Push back many factors as shared_ptr's in a container (factors are not + * copied) + */ + template + HasDerivedElementType push_back(const CONTAINER& container) { + Base::push_back(container.begin(), container.end()); + } + + /// Push back non-pointer objects in a container (factors are copied). + template + HasDerivedValueType push_back(const CONTAINER& container) { + Base::push_back(container.begin(), container.end()); + } + /// Add a nonlinear factor as a shared ptr. void add(boost::shared_ptr factor); @@ -127,7 +152,8 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { * @param continuousValues: Dictionary of continuous values. * @return HybridGaussianFactorGraph::shared_ptr */ - HybridGaussianFactorGraph linearize(const Values& continuousValues) const; + HybridGaussianFactorGraph::shared_ptr linearize( + const Values& continuousValues) const; }; template <> diff --git a/gtsam/hybrid/HybridNonlinearISAM.cpp b/gtsam/hybrid/HybridNonlinearISAM.cpp new file mode 100644 index 000000000..57e0daf8d --- /dev/null +++ b/gtsam/hybrid/HybridNonlinearISAM.cpp @@ -0,0 +1,114 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridNonlinearISAM.cpp + * @date Sep 12, 2022 + * @author Varun Agrawal + */ + +#include +#include +#include + +#include + +using namespace std; + +namespace gtsam { + +/* ************************************************************************* */ +void HybridNonlinearISAM::saveGraph(const string& s, + const KeyFormatter& keyFormatter) const { + isam_.saveGraph(s, keyFormatter); +} + +/* ************************************************************************* */ +void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors, + const Values& initialValues, + const boost::optional& maxNrLeaves, + const boost::optional& ordering) { + if (newFactors.size() > 0) { + // Reorder and relinearize every reorderInterval updates + if (reorderInterval_ > 0 && ++reorderCounter_ >= reorderInterval_) { + reorder_relinearize(); + reorderCounter_ = 0; + } + + factors_.push_back(newFactors); + + // Linearize new factors and insert them + // TODO: optimize for whole config? + linPoint_.insert(initialValues); + + boost::shared_ptr linearizedNewFactors = + newFactors.linearize(linPoint_); + + // Update ISAM + isam_.update(*linearizedNewFactors, maxNrLeaves, ordering, + eliminationFunction_); + } +} + +/* ************************************************************************* */ +void HybridNonlinearISAM::reorder_relinearize() { + if (factors_.size() > 0) { + // Obtain the new linearization point + const Values newLinPoint = estimate(); + + isam_.clear(); + + // Just recreate the whole BayesTree + // TODO: allow for constrained ordering here + // TODO: decouple relinearization and reordering to avoid + isam_.update(*factors_.linearize(newLinPoint), boost::none, boost::none, + eliminationFunction_); + + // Update linearization point + linPoint_ = newLinPoint; + } +} + +/* ************************************************************************* */ +Values HybridNonlinearISAM::estimate() { + Values result; + if (isam_.size() > 0) { + HybridValues values = isam_.optimize(); + assignment_ = values.discrete(); + return linPoint_.retract(values.continuous()); + } else { + return linPoint_; + } +} + +// /* ************************************************************************* +// */ Matrix HybridNonlinearISAM::marginalCovariance(Key key) const { +// return isam_.marginalCovariance(key); +// } + +/* ************************************************************************* */ +void HybridNonlinearISAM::print(const string& s, + const KeyFormatter& keyFormatter) const { + cout << s << "ReorderInterval: " << reorderInterval_ + << " Current Count: " << reorderCounter_ << endl; + isam_.print("HybridGaussianISAM:\n", keyFormatter); + linPoint_.print("Linearization Point:\n", keyFormatter); + factors_.print("Nonlinear Graph:\n", keyFormatter); +} + +/* ************************************************************************* */ +void HybridNonlinearISAM::printStats() const { + isam_.getCliqueData().getStats().print(); +} + +/* ************************************************************************* */ + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridNonlinearISAM.h b/gtsam/hybrid/HybridNonlinearISAM.h new file mode 100644 index 000000000..47aa81c55 --- /dev/null +++ b/gtsam/hybrid/HybridNonlinearISAM.h @@ -0,0 +1,131 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridNonlinearISAM.h + * @date Sep 12, 2022 + * @author Varun Agrawal + */ + +#pragma once + +#include +#include + +namespace gtsam { +/** + * Wrapper class to manage ISAM in a nonlinear context + */ +class GTSAM_EXPORT HybridNonlinearISAM { + protected: + /** The internal iSAM object */ + gtsam::HybridGaussianISAM isam_; + + /** The current linearization point */ + Values linPoint_; + + /// The discrete assignment + DiscreteValues assignment_; + + /** The original factors, used when relinearizing */ + HybridNonlinearFactorGraph factors_; + + /** The reordering interval and counter */ + int reorderInterval_; + int reorderCounter_; + + /** The elimination function */ + HybridGaussianFactorGraph::Eliminate eliminationFunction_; + + public: + /// @name Standard Constructors + /// @{ + + /** + * Periodically reorder and relinearize + * @param reorderInterval is the number of updates between reorderings, + * 0 never reorders (and is dangerous for memory consumption) + * 1 (default) reorders every time, in worse case is batch every update + * typical values are 50 or 100 + */ + HybridNonlinearISAM( + int reorderInterval = 1, + const HybridGaussianFactorGraph::Eliminate& eliminationFunction = + HybridGaussianFactorGraph::EliminationTraitsType::DefaultEliminate) + : reorderInterval_(reorderInterval), + reorderCounter_(0), + eliminationFunction_(eliminationFunction) {} + + /// @} + /// @name Standard Interface + /// @{ + + /** Return the current solution estimate */ + Values estimate(); + + // /** find the marginal covariance for a single variable */ + // Matrix marginalCovariance(Key key) const; + + // access + + /** access the underlying bayes tree */ + const HybridGaussianISAM& bayesTree() const { return isam_; } + + /** + * @brief Prune the underlying Bayes tree. + * + * @param maxNumberLeaves The max number of leaf nodes to keep. + */ + void prune(const size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); } + + /** Return the current linearization point */ + const Values& getLinearizationPoint() const { return linPoint_; } + + /** Return the current discrete assignment */ + const DiscreteValues& getAssignment() const { return assignment_; } + + /** get underlying nonlinear graph */ + const HybridNonlinearFactorGraph& getFactorsUnsafe() const { + return factors_; + } + + /** get counters */ + int reorderInterval() const { return reorderInterval_; } ///< TODO: comment + int reorderCounter() const { return reorderCounter_; } ///< TODO: comment + + /** prints out all contents of the system */ + void print(const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /** prints out clique statistics */ + void printStats() const; + + /** saves the Tree to a text file in GraphViz format */ + void saveGraph(const std::string& s, + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} + /// @name Advanced Interface + /// @{ + + /** Add new factors along with their initial linearization points */ + void update(const HybridNonlinearFactorGraph& newFactors, + const Values& initialValues, + const boost::optional& maxNrLeaves = boost::none, + const boost::optional& ordering = boost::none); + + /** Relinearization and reordering of variables */ + void reorder_relinearize(); + + /// @} +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h index 5e1bd4164..4928f9384 100644 --- a/gtsam/hybrid/HybridValues.h +++ b/gtsam/hybrid/HybridValues.h @@ -31,60 +31,78 @@ namespace gtsam { /** - * HybridValues represents a collection of DiscreteValues and VectorValues. It - * is typically used to store the variables of a HybridGaussianFactorGraph. + * HybridValues represents a collection of DiscreteValues and VectorValues. + * It is typically used to store the variables of a HybridGaussianFactorGraph. * Optimizing a HybridGaussianBayesNet returns this class. */ class GTSAM_EXPORT HybridValues { - public: + private: // DiscreteValue stored the discrete components of the HybridValues. - DiscreteValues discrete; + DiscreteValues discrete_; // VectorValue stored the continuous components of the HybridValues. - VectorValues continuous; + VectorValues continuous_; - // Default constructor creates an empty HybridValues. - HybridValues() : discrete(), continuous(){}; + public: + /// @name Standard Constructors + /// @{ - // Construct from DiscreteValues and VectorValues. + /// Default constructor creates an empty HybridValues. + HybridValues() = default; + + /// Construct from DiscreteValues and VectorValues. HybridValues(const DiscreteValues& dv, const VectorValues& cv) - : discrete(dv), continuous(cv){}; + : discrete_(dv), continuous_(cv){}; - // print required by Testable for unit testing + /// @} + /// @name Testable + /// @{ + + /// print required by Testable for unit testing void print(const std::string& s = "HybridValues", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { std::cout << s << ": \n"; - discrete.print(" Discrete", keyFormatter); // print discrete components - continuous.print(" Continuous", - keyFormatter); // print continuous components + discrete_.print(" Discrete", keyFormatter); // print discrete components + continuous_.print(" Continuous", + keyFormatter); // print continuous components }; - // equals required by Testable for unit testing + /// equals required by Testable for unit testing bool equals(const HybridValues& other, double tol = 1e-9) const { - return discrete.equals(other.discrete, tol) && - continuous.equals(other.continuous, tol); + return discrete_.equals(other.discrete_, tol) && + continuous_.equals(other.continuous_, tol); } - // Check whether a variable with key \c j exists in DiscreteValue. - bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); }; + /// @} + /// @name Interface + /// @{ - // Check whether a variable with key \c j exists in VectorValue. - bool existsVector(Key j) { return continuous.exists(j); }; + /// Return the discrete MPE assignment + DiscreteValues discrete() const { return discrete_; } - // Check whether a variable with key \c j exists. + /// Return the delta update for the continuous vectors + VectorValues continuous() const { return continuous_; } + + /// Check whether a variable with key \c j exists in DiscreteValue. + bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); }; + + /// Check whether a variable with key \c j exists in VectorValue. + bool existsVector(Key j) { return continuous_.exists(j); }; + + /// Check whether a variable with key \c j exists. bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; /** Insert a discrete \c value with key \c j. Replaces the existing value if * the key \c j is already used. * @param value The vector to be inserted. * @param j The index with which the value will be associated. */ - void insert(Key j, int value) { discrete[j] = value; }; + void insert(Key j, int value) { discrete_[j] = value; }; /** Insert a vector \c value with key \c j. Throws an invalid_argument * exception if the key \c j is already used. * @param value The vector to be inserted. * @param j The index with which the value will be associated. */ - void insert(Key j, const Vector& value) { continuous.insert(j, value); } + void insert(Key j, const Vector& value) { continuous_.insert(j, value); } // TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h @@ -92,13 +110,13 @@ class GTSAM_EXPORT HybridValues { * Read/write access to the discrete value with key \c j, throws * std::out_of_range if \c j does not exist. */ - size_t& atDiscrete(Key j) { return discrete.at(j); }; + size_t& atDiscrete(Key j) { return discrete_.at(j); }; /** * Read/write access to the vector value with key \c j, throws * std::out_of_range if \c j does not exist. */ - Vector& at(Key j) { return continuous.at(j); }; + Vector& at(Key j) { return continuous_.at(j); }; /// @name Wrapper support /// @{ @@ -112,8 +130,8 @@ class GTSAM_EXPORT HybridValues { std::string html( const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { std::stringstream ss; - ss << this->discrete.html(keyFormatter); - ss << this->continuous.html(keyFormatter); + ss << this->discrete_.html(keyFormatter); + ss << this->continuous_.html(keyFormatter); return ss.str(); }; diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index 3cd21e32e..5e7337d0c 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -100,11 +100,23 @@ class MixtureFactor : public HybridFactor { bool normalized = false) : Base(keys, discreteKeys), normalized_(normalized) { std::vector nonlinear_factors; + KeySet continuous_keys_set(keys.begin(), keys.end()); + KeySet factor_keys_set; for (auto&& f : factors) { + // Insert all factor continuous keys in the continuous keys set. + std::copy(f->keys().begin(), f->keys().end(), + std::inserter(factor_keys_set, factor_keys_set.end())); + nonlinear_factors.push_back( boost::dynamic_pointer_cast(f)); } factors_ = Factors(discreteKeys, nonlinear_factors); + + if (continuous_keys_set != factor_keys_set) { + throw std::runtime_error( + "The specified continuous keys and the keys in the factors don't " + "match!"); + } } ~MixtureFactor() = default; diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index aa63259d9..86029a48a 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -6,8 +6,8 @@ namespace gtsam { #include class HybridValues { - gtsam::DiscreteValues discrete; - gtsam::VectorValues continuous; + gtsam::DiscreteValues discrete() const; + gtsam::VectorValues continuous() const; HybridValues(); HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv); void print(string s = "HybridValues", @@ -99,6 +99,8 @@ class HybridBayesTree { bool empty() const; const HybridBayesTreeClique* operator[](size_t j) const; + gtsam::HybridValues optimize() const; + string dot(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; }; diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index 52ef0439e..3ae8f0bb1 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -115,7 +115,6 @@ inline std::pair> makeBinaryOrdering( /* *************************************************************************** */ using MotionModel = BetweenFactor; -// using MotionMixture = MixtureFactor; // Test fixture with switching network. struct Switching { @@ -125,12 +124,15 @@ struct Switching { HybridGaussianFactorGraph linearizedFactorGraph; Values linearizationPoint; - /// Create with given number of time steps. + /** + * @brief Create with given number of time steps. + * + * @param K The total number of timesteps. + * @param between_sigma The stddev between poses. + * @param prior_sigma The stddev on priors (also used for measurements). + */ Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1) : K(K) { - using symbol_shorthand::M; - using symbol_shorthand::X; - // Create DiscreteKeys for binary K modes, modes[0] will not be used. for (size_t k = 0; k <= K; k++) { modes.emplace_back(M(k), 2); @@ -145,7 +147,7 @@ struct Switching { // Add "motion models". for (size_t k = 1; k < K; k++) { KeyVector keys = {X(k), X(k + 1)}; - auto motion_models = motionModels(k); + auto motion_models = motionModels(k, between_sigma); std::vector components; for (auto &&f : motion_models) { components.push_back(boost::dynamic_pointer_cast(f)); @@ -155,7 +157,7 @@ struct Switching { } // Add measurement factors - auto measurement_noise = noiseModel::Isotropic::Sigma(1, 0.1); + auto measurement_noise = noiseModel::Isotropic::Sigma(1, prior_sigma); for (size_t k = 2; k <= K; k++) { nonlinearFactorGraph.emplace_nonlinear>( X(k), 1.0 * (k - 1), measurement_noise); @@ -169,15 +171,14 @@ struct Switching { linearizationPoint.insert(X(k), static_cast(k)); } - linearizedFactorGraph = nonlinearFactorGraph.linearize(linearizationPoint); + // The ground truth is robot moving forward + // and one less than the linearization point + linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint); } // Create motion models for a given time step static std::vector motionModels(size_t k, double sigma = 1.0) { - using symbol_shorthand::M; - using symbol_shorthand::X; - auto noise_model = noiseModel::Isotropic::Sigma(1, sigma); auto still = boost::make_shared(X(k), X(k + 1), 0.0, noise_model), diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index f3db83955..5885fdcdc 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -18,7 +18,10 @@ * @date December 2021 */ +#include #include +#include +#include #include "Switching.h" @@ -27,6 +30,8 @@ using namespace std; using namespace gtsam; +using namespace gtsam::serializationTestHelpers; + using noiseModel::Isotropic; using symbol_shorthand::M; using symbol_shorthand::X; @@ -47,6 +52,20 @@ TEST(HybridBayesNet, Creation) { EXPECT(df.equals(expected)); } +/* ****************************************************************************/ +// Test adding a bayes net to another one. +TEST(HybridBayesNet, Add) { + HybridBayesNet bayesNet; + + bayesNet.add(Asia, "99/1"); + + DiscreteConditional expected(Asia, "99/1"); + + HybridBayesNet other; + other.push_back(bayesNet); + EXPECT(bayesNet.equals(other)); +} + /* ****************************************************************************/ // Test choosing an assignment of conditionals TEST(HybridBayesNet, Choose) { @@ -72,19 +91,128 @@ TEST(HybridBayesNet, Choose) { EXPECT_LONGS_EQUAL(4, gbn.size()); EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( - hybridBayesNet->atGaussian(0)))(assignment), + hybridBayesNet->atMixture(0)))(assignment), *gbn.at(0))); EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( - hybridBayesNet->atGaussian(1)))(assignment), + hybridBayesNet->atMixture(1)))(assignment), *gbn.at(1))); EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( - hybridBayesNet->atGaussian(2)))(assignment), + hybridBayesNet->atMixture(2)))(assignment), *gbn.at(2))); EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( - hybridBayesNet->atGaussian(3)))(assignment), + hybridBayesNet->atMixture(3)))(assignment), *gbn.at(3))); } +/* ****************************************************************************/ +// Test bayes net optimize +TEST(HybridBayesNet, OptimizeAssignment) { + Switching s(4); + + Ordering ordering; + for (auto&& kvp : s.linearizationPoint) { + ordering += kvp.key; + } + + HybridBayesNet::shared_ptr hybridBayesNet; + HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; + std::tie(hybridBayesNet, remainingFactorGraph) = + s.linearizedFactorGraph.eliminatePartialSequential(ordering); + + DiscreteValues assignment; + assignment[M(1)] = 1; + assignment[M(2)] = 1; + assignment[M(3)] = 1; + + VectorValues delta = hybridBayesNet->optimize(assignment); + + // The linearization point has the same value as the key index, + // e.g. X(1) = 1, X(2) = 2, + // but the factors specify X(k) = k-1, so delta should be -1. + VectorValues expected_delta; + expected_delta.insert(make_pair(X(1), -Vector1::Ones())); + expected_delta.insert(make_pair(X(2), -Vector1::Ones())); + expected_delta.insert(make_pair(X(3), -Vector1::Ones())); + expected_delta.insert(make_pair(X(4), -Vector1::Ones())); + + EXPECT(assert_equal(expected_delta, delta)); +} + +/* ****************************************************************************/ +// Test bayes net optimize +TEST(HybridBayesNet, Optimize) { + Switching s(4); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + s.linearizedFactorGraph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + + DiscreteValues expectedAssignment; + expectedAssignment[M(1)] = 1; + expectedAssignment[M(2)] = 0; + expectedAssignment[M(3)] = 1; + EXPECT(assert_equal(expectedAssignment, delta.discrete())); + + VectorValues expectedValues; + expectedValues.insert(X(1), -0.999904 * Vector1::Ones()); + expectedValues.insert(X(2), -0.99029 * Vector1::Ones()); + expectedValues.insert(X(3), -1.00971 * Vector1::Ones()); + expectedValues.insert(X(4), -1.0001 * Vector1::Ones()); + + EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); +} + +/* ****************************************************************************/ +// Test bayes net multifrontal optimize +TEST(HybridBayesNet, OptimizeMultifrontal) { + Switching s(4); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesTree::shared_ptr hybridBayesTree = + s.linearizedFactorGraph.eliminateMultifrontal(hybridOrdering); + HybridValues delta = hybridBayesTree->optimize(); + + VectorValues expectedValues; + expectedValues.insert(X(1), -0.999904 * Vector1::Ones()); + expectedValues.insert(X(2), -0.99029 * Vector1::Ones()); + expectedValues.insert(X(3), -1.00971 * Vector1::Ones()); + expectedValues.insert(X(4), -1.0001 * Vector1::Ones()); + + EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); +} + +/* ****************************************************************************/ +// Test bayes net pruning +TEST(HybridBayesNet, Prune) { + Switching s(4); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + s.linearizedFactorGraph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + + auto prunedBayesNet = hybridBayesNet->prune(2); + HybridValues pruned_delta = prunedBayesNet.optimize(); + + EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete())); + EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); +} + +/* ****************************************************************************/ +// Test HybridBayesNet serialization. +TEST(HybridBayesNet, Serialization) { + Switching s(4); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); + + EXPECT(equalsObj(hbn)); + EXPECT(equalsXML(hbn)); + EXPECT(equalsBinary(hbn)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp new file mode 100644 index 000000000..0908b8cb5 --- /dev/null +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -0,0 +1,166 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridBayesTree.cpp + * @brief Unit tests for HybridBayesTree + * @author Varun Agrawal + * @date August 2022 + */ + +#include +#include +#include +#include + +#include "Switching.h" + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::M; +using symbol_shorthand::X; + +/* ****************************************************************************/ +// Test for optimizing a HybridBayesTree with a given assignment. +TEST(HybridBayesTree, OptimizeAssignment) { + Switching s(4); + + HybridGaussianISAM isam; + HybridGaussianFactorGraph graph1; + + // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4 + for (size_t i = 1; i < 4; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the Gaussian factors, 1 prior on X(1), + // 3 measurements on X(2), X(3), X(4) + graph1.push_back(s.linearizedFactorGraph.at(0)); + for (size_t i = 4; i <= 7; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the discrete factors + for (size_t i = 7; i <= 9; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + isam.update(graph1); + + DiscreteValues assignment; + assignment[M(1)] = 1; + assignment[M(2)] = 1; + assignment[M(3)] = 1; + + VectorValues delta = isam.optimize(assignment); + + // The linearization point has the same value as the key index, + // e.g. X(1) = 1, X(2) = 2, + // but the factors specify X(k) = k-1, so delta should be -1. + VectorValues expected_delta; + expected_delta.insert(make_pair(X(1), -Vector1::Ones())); + expected_delta.insert(make_pair(X(2), -Vector1::Ones())); + expected_delta.insert(make_pair(X(3), -Vector1::Ones())); + expected_delta.insert(make_pair(X(4), -Vector1::Ones())); + + EXPECT(assert_equal(expected_delta, delta)); + + // Create ordering. + Ordering ordering; + for (size_t k = 1; k <= s.K; k++) ordering += X(k); + + HybridBayesNet::shared_ptr hybridBayesNet; + HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; + std::tie(hybridBayesNet, remainingFactorGraph) = + s.linearizedFactorGraph.eliminatePartialSequential(ordering); + + GaussianBayesNet gbn = hybridBayesNet->choose(assignment); + VectorValues expected = gbn.optimize(); + + EXPECT(assert_equal(expected, delta)); +} + +/* ****************************************************************************/ +// Test for optimizing a HybridBayesTree. +TEST(HybridBayesTree, Optimize) { + Switching s(4); + + HybridGaussianISAM isam; + HybridGaussianFactorGraph graph1; + + // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4 + for (size_t i = 1; i < 4; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the Gaussian factors, 1 prior on X(1), + // 3 measurements on X(2), X(3), X(4) + graph1.push_back(s.linearizedFactorGraph.at(0)); + for (size_t i = 4; i <= 6; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the discrete factors + for (size_t i = 7; i <= 9; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + isam.update(graph1); + + HybridValues delta = isam.optimize(); + + // Create ordering. + Ordering ordering; + for (size_t k = 1; k <= s.K; k++) ordering += X(k); + + HybridBayesNet::shared_ptr hybridBayesNet; + HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; + std::tie(hybridBayesNet, remainingFactorGraph) = + s.linearizedFactorGraph.eliminatePartialSequential(ordering); + + DiscreteFactorGraph dfg; + for (auto&& f : *remainingFactorGraph) { + auto factor = dynamic_pointer_cast(f); + dfg.push_back( + boost::dynamic_pointer_cast(factor->inner())); + } + + DiscreteValues expectedMPE = dfg.optimize(); + VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE); + + EXPECT(assert_equal(expectedMPE, delta.discrete())); + EXPECT(assert_equal(expectedValues, delta.continuous())); +} + +/* ****************************************************************************/ +// Test HybridBayesTree serialization. +TEST(HybridBayesTree, Serialization) { + Switching s(4); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesTree hbt = + *(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); + + using namespace gtsam::serializationTestHelpers; + EXPECT(equalsObj(hbt)); + EXPECT(equalsXML(hbt)); + EXPECT(equalsBinary(hbt)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp similarity index 90% rename from gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp rename to gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index e4bd0e084..d199d7611 100644 --- a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -184,8 +184,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) { hfg.add(DecisionTreeFactor(m1, {2, 8})); hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); - HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal( - Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)})); + HybridBayesTree::shared_ptr result = + hfg.eliminateMultifrontal(hfg.getHybridOrdering()); // The bayes tree should have 3 cliques EXPECT_LONGS_EQUAL(3, result->size()); @@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) { hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8}))); // Get a constrained ordering keeping c1 last - auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(1)}); + auto ordering_full = hfg.getHybridOrdering(); // Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1) HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full); @@ -484,8 +484,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) { } HybridBayesNet::shared_ptr hbn; HybridGaussianFactorGraph::shared_ptr remaining; - std::tie(hbn, remaining) = - hfg->eliminatePartialSequential(ordering_partial); + std::tie(hbn, remaining) = hfg->eliminatePartialSequential(ordering_partial); EXPECT_LONGS_EQUAL(14, hbn->size()); EXPECT_LONGS_EQUAL(11, remaining->size()); @@ -501,6 +500,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) { } } +/* ************************************************************************* */ TEST(HybridGaussianFactorGraph, optimize) { HybridGaussianFactorGraph hfg; @@ -522,6 +522,46 @@ TEST(HybridGaussianFactorGraph, optimize) { EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); } + +/* ************************************************************************* */ +// Test adding of gaussian conditional and re-elimination. +TEST(HybridGaussianFactorGraph, Conditionals) { + Switching switching(4); + HybridGaussianFactorGraph hfg; + + hfg.push_back(switching.linearizedFactorGraph.at(0)); // P(X1) + Ordering ordering; + ordering.push_back(X(1)); + HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering); + + hfg.push_back(switching.linearizedFactorGraph.at(1)); // P(X1, X2 | M1) + hfg.push_back(*bayes_net); + hfg.push_back(switching.linearizedFactorGraph.at(2)); // P(X2, X3 | M2) + hfg.push_back(switching.linearizedFactorGraph.at(5)); // P(M1) + ordering.push_back(X(2)); + ordering.push_back(X(3)); + ordering.push_back(M(1)); + ordering.push_back(M(2)); + + bayes_net = hfg.eliminateSequential(ordering); + + HybridValues result = bayes_net->optimize(); + + Values expected_continuous; + expected_continuous.insert(X(1), 0); + expected_continuous.insert(X(2), 1); + expected_continuous.insert(X(3), 2); + expected_continuous.insert(X(4), 4); + Values result_continuous = + switching.linearizationPoint.retract(result.continuous()); + EXPECT(assert_equal(expected_continuous, result_continuous)); + + DiscreteValues expected_discrete; + expected_discrete[M(1)] = 1; + expected_discrete[M(2)] = 1; + EXPECT(assert_equal(expected_discrete, result.discrete())); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridIncremental.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp similarity index 98% rename from gtsam/hybrid/tests/testHybridIncremental.cpp rename to gtsam/hybrid/tests/testHybridGaussianISAM.cpp index 4449aba0b..a5e3903d9 100644 --- a/gtsam/hybrid/tests/testHybridIncremental.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp @@ -235,7 +235,7 @@ TEST(HybridGaussianElimination, Approx_inference) { size_t maxNrLeaves = 5; incrementalHybrid.update(graph1); - incrementalHybrid.prune(M(3), maxNrLeaves); + incrementalHybrid.prune(maxNrLeaves); /* unpruned factor is: @@ -329,7 +329,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { // Run update with pruning size_t maxComponents = 5; incrementalHybrid.update(graph1); - incrementalHybrid.prune(M(3), maxComponents); + incrementalHybrid.prune(maxComponents); // Check if we have a bayes tree with 4 hybrid nodes, // each with 2, 4, 8, and 5 (pruned) leaves respetively. @@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { EXPECT_LONGS_EQUAL( 2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( - 4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); + 3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( 5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( @@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { // Run update with pruning a second time. incrementalHybrid.update(graph2); - incrementalHybrid.prune(M(4), maxComponents); + incrementalHybrid.prune(maxComponents); // Check if we have a bayes tree with pruned hybrid nodes, // with 5 (pruned) leaves. @@ -399,7 +399,7 @@ TEST(HybridGaussianISAM, NonTrivial) { initial.insert(Z(0), Pose2(0.0, 2.0, 0.0)); initial.insert(W(0), Pose2(0.0, 3.0, 0.0)); - HybridGaussianFactorGraph gfg = fg.linearize(initial); + HybridGaussianFactorGraph gfg = *fg.linearize(initial); fg = HybridNonlinearFactorGraph(); HybridGaussianISAM inc; @@ -444,7 +444,7 @@ TEST(HybridGaussianISAM, NonTrivial) { // The leg link did not move so we set the expected pose accordingly. initial.insert(W(1), Pose2(0.0, 3.0, 0.0)); - gfg = fg.linearize(initial); + gfg = *fg.linearize(initial); fg = HybridNonlinearFactorGraph(); // Update without pruning @@ -483,7 +483,7 @@ TEST(HybridGaussianISAM, NonTrivial) { initial.insert(Z(2), Pose2(2.0, 2.0, 0.0)); initial.insert(W(2), Pose2(0.0, 3.0, 0.0)); - gfg = fg.linearize(initial); + gfg = *fg.linearize(initial); fg = HybridNonlinearFactorGraph(); // Now we prune! @@ -496,7 +496,7 @@ TEST(HybridGaussianISAM, NonTrivial) { // The MHS at this point should be a 2 level tree on (1, 2). // 1 has 2 choices, and 2 has 4 choices. inc.update(gfg); - inc.prune(M(2), 2); + inc.prune(2); /*************** Run Round 4 ***************/ // Add odometry factor with discrete modes. @@ -526,12 +526,12 @@ TEST(HybridGaussianISAM, NonTrivial) { initial.insert(Z(3), Pose2(3.0, 2.0, 0.0)); initial.insert(W(3), Pose2(0.0, 3.0, 0.0)); - gfg = fg.linearize(initial); + gfg = *fg.linearize(initial); fg = HybridNonlinearFactorGraph(); // Keep pruning! inc.update(gfg); - inc.prune(M(3), 3); + inc.prune(3); // The final discrete graph should not be empty since we have eliminated // all continuous variables. diff --git a/gtsam/hybrid/tests/testHybridLookupDAG.cpp b/gtsam/hybrid/tests/testHybridLookupDAG.cpp deleted file mode 100644 index 0ab012d10..000000000 --- a/gtsam/hybrid/tests/testHybridLookupDAG.cpp +++ /dev/null @@ -1,272 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testHybridLookupDAG.cpp - * @date Aug, 2022 - * @author Shangjie Xue - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// Include for test suite -#include - -#include - -using namespace std; -using namespace gtsam; -using noiseModel::Isotropic; -using symbol_shorthand::M; -using symbol_shorthand::X; - -TEST(HybridLookupTable, basics) { - // create a conditional gaussian node - Matrix S1(2, 2); - S1(0, 0) = 1; - S1(1, 0) = 2; - S1(0, 1) = 3; - S1(1, 1) = 4; - - Matrix S2(2, 2); - S2(0, 0) = 6; - S2(1, 0) = 0.2; - S2(0, 1) = 8; - S2(1, 1) = 0.4; - - Matrix R1(2, 2); - R1(0, 0) = 0.1; - R1(1, 0) = 0.3; - R1(0, 1) = 0.0; - R1(1, 1) = 0.34; - - Matrix R2(2, 2); - R2(0, 0) = 0.1; - R2(1, 0) = 0.3; - R2(0, 1) = 0.0; - R2(1, 1) = 0.34; - - SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - - Vector2 d1(0.2, 0.5), d2(0.5, 0.2); - - auto conditional0 = boost::make_shared(X(1), d1, R1, - X(2), S1, model), - conditional1 = boost::make_shared(X(1), d2, R2, - X(2), S2, model); - - // Create decision tree - DiscreteKey m1(1, 2); - GaussianMixture::Conditionals conditionals( - {m1}, - vector{conditional0, conditional1}); - // GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals); - - boost::shared_ptr mixtureFactor( - new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals)); - - HybridConditional hc(mixtureFactor); - - GaussianMixture::Conditionals conditional2 = - boost::static_pointer_cast(hc.inner())->conditionals(); - - DiscreteValues dv; - dv[1] = 1; - - VectorValues cv; - cv.insert(X(2), Vector2(0.0, 0.0)); - - HybridValues hv(dv, cv); - - // std::cout << conditional2(values).markdown(); - EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6)); - EXPECT(conditional2(dv) == conditionals(dv)); - HybridLookupTable hlt(hc); - - // hlt.argmaxInPlace(&hv); - - HybridLookupDAG dag; - dag.push_back(hlt); - dag.argmax(hv); - - // HybridBayesNet hbn; - // hbn.push_back(hc); - // hbn.optimize(); -} - -TEST(HybridLookupTable, hybrid_argmax) { - Matrix S1(2, 2); - S1(0, 0) = 1; - S1(1, 0) = 0; - S1(0, 1) = 0; - S1(1, 1) = 1; - - Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); - - SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - - auto conditional0 = - boost::make_shared(X(1), d1, S1, model), - conditional1 = - boost::make_shared(X(1), d2, S1, model); - - DiscreteKey m1(1, 2); - GaussianMixture::Conditionals conditionals( - {m1}, - vector{conditional0, conditional1}); - boost::shared_ptr mixtureFactor( - new GaussianMixture({X(1)}, {}, {m1}, conditionals)); - - HybridConditional hc(mixtureFactor); - - DiscreteValues dv; - dv[1] = 1; - VectorValues cv; - // cv.insert(X(2),Vector2(0.0, 0.0)); - HybridValues hv(dv, cv); - - HybridLookupTable hlt(hc); - - hlt.argmaxInPlace(&hv); - - EXPECT(assert_equal(hv.at(X(1)), d2)); -} - -TEST(HybridLookupTable, discrete_argmax) { - DiscreteKey X(0, 2), Y(1, 2); - - auto conditional = boost::make_shared(X | Y = "0/1 3/2"); - - HybridConditional hc(conditional); - - HybridLookupTable hlt(hc); - - DiscreteValues dv; - dv[1] = 0; - VectorValues cv; - // cv.insert(X(2),Vector2(0.0, 0.0)); - HybridValues hv(dv, cv); - - hlt.argmaxInPlace(&hv); - - EXPECT(assert_equal(hv.atDiscrete(0), 1)); - - DecisionTreeFactor f1(X, "2 3"); - auto conditional2 = boost::make_shared(1, f1); - - HybridConditional hc2(conditional2); - - HybridLookupTable hlt2(hc2); - - HybridValues hv2; - - hlt2.argmaxInPlace(&hv2); - - EXPECT(assert_equal(hv2.atDiscrete(0), 1)); -} - -TEST(HybridLookupTable, gaussian_argmax) { - Matrix S1(2, 2); - S1(0, 0) = 1; - S1(1, 0) = 0; - S1(0, 1) = 0; - S1(1, 1) = 1; - - Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); - - SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - - auto conditional = - boost::make_shared(X(1), d1, S1, X(2), -S1, model); - - HybridConditional hc(conditional); - - HybridLookupTable hlt(hc); - - DiscreteValues dv; - // dv[1]=0; - VectorValues cv; - cv.insert(X(2), d2); - HybridValues hv(dv, cv); - - hlt.argmaxInPlace(&hv); - - EXPECT(assert_equal(hv.at(X(1)), d1 + d2)); -} - -TEST(HybridLookupDAG, argmax) { - Matrix S1(2, 2); - S1(0, 0) = 1; - S1(1, 0) = 0; - S1(0, 1) = 0; - S1(1, 1) = 1; - - Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); - - SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - - auto conditional0 = - boost::make_shared(X(2), d1, S1, model), - conditional1 = - boost::make_shared(X(2), d2, S1, model); - - DiscreteKey m1(1, 2); - GaussianMixture::Conditionals conditionals( - {m1}, - vector{conditional0, conditional1}); - boost::shared_ptr mixtureFactor( - new GaussianMixture({X(2)}, {}, {m1}, conditionals)); - HybridConditional hc2(mixtureFactor); - HybridLookupTable hlt2(hc2); - - auto conditional2 = - boost::make_shared(X(1), d1, S1, X(2), -S1, model); - - HybridConditional hc1(conditional2); - HybridLookupTable hlt1(hc1); - - DecisionTreeFactor f1(m1, "2 3"); - auto discrete_conditional = boost::make_shared(1, f1); - - HybridConditional hc3(discrete_conditional); - HybridLookupTable hlt3(hc3); - - HybridLookupDAG dag; - dag.push_back(hlt1); - dag.push_back(hlt2); - dag.push_back(hlt3); - auto hv = dag.argmax(); - - EXPECT(assert_equal(hv.atDiscrete(1), 1)); - EXPECT(assert_equal(hv.at(X(2)), d2)); - EXPECT(assert_equal(hv.at(X(1)), d2 + d1)); -} - -/* ************************************************************************* */ -int main() { - TestResult tr; - return TestRegistry::runAllTests(tr); -} -/* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 018b017a9..9e93eaba3 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -60,7 +60,7 @@ TEST(HybridFactorGraph, GaussianFactorGraph) { Values linearizationPoint; linearizationPoint.insert(X(0), 0); - HybridGaussianFactorGraph ghfg = fg.linearize(linearizationPoint); + HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint); // Add a factor to the GaussianFactorGraph ghfg.add(JacobianFactor(X(0), I_1x1, Vector1(5))); @@ -139,7 +139,7 @@ TEST(HybridGaussianFactorGraph, Resize) { linearizationPoint.insert(X(1), 1); // Generate `HybridGaussianFactorGraph` by linearizing - HybridGaussianFactorGraph gfg = nhfg.linearize(linearizationPoint); + HybridGaussianFactorGraph gfg = *nhfg.linearize(linearizationPoint); EXPECT_LONGS_EQUAL(gfg.size(), 3); @@ -147,6 +147,32 @@ TEST(HybridGaussianFactorGraph, Resize) { EXPECT_LONGS_EQUAL(gfg.size(), 0); } +/*************************************************************************** + * Test that the MixtureFactor reports correctly if the number of continuous + * keys provided do not match the keys in the factors. + */ +TEST(HybridGaussianFactorGraph, MixtureFactor) { + auto nonlinearFactor = boost::make_shared>( + X(0), X(1), 0.0, Isotropic::Sigma(1, 0.1)); + auto discreteFactor = boost::make_shared(); + + auto noise_model = noiseModel::Isotropic::Sigma(1, 1.0); + auto still = boost::make_shared(X(0), X(1), 0.0, noise_model), + moving = boost::make_shared(X(0), X(1), 1.0, noise_model); + + std::vector components = {still, moving}; + + // Check for exception when number of continuous keys are under-specified. + KeyVector contKeys = {X(0)}; + THROWS_EXCEPTION(boost::make_shared( + contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components)); + + // Check for exception when number of continuous keys are too many. + contKeys = {X(0), X(1), X(2)}; + THROWS_EXCEPTION(boost::make_shared( + contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components)); +} + /***************************************************************************** * Test push_back on HFG makes the correct distinction. */ @@ -224,7 +250,7 @@ TEST(HybridFactorGraph, Linearization) { // Linearize here: HybridGaussianFactorGraph actualLinearized = - self.nonlinearFactorGraph.linearize(self.linearizationPoint); + *self.nonlinearFactorGraph.linearize(self.linearizationPoint); EXPECT_LONGS_EQUAL(7, actualLinearized.size()); } @@ -257,14 +283,6 @@ TEST(GaussianElimination, Eliminate_x1) { // Add first hybrid factor factors.push_back(self.linearizedFactorGraph[1]); - // TODO(Varun) remove this block since sum is no longer exposed. - // // Check that sum works: - // auto sum = factors.sum(); - // Assignment mode; - // mode[M(1)] = 1; - // auto actual = sum(mode); // Selects one of 2 modes. - // EXPECT_LONGS_EQUAL(2, actual.size()); // Prior and motion model. - // Eliminate x1 Ordering ordering; ordering += X(1); @@ -289,15 +307,6 @@ TEST(HybridsGaussianElimination, Eliminate_x2) { factors.push_back(self.linearizedFactorGraph[1]); // involves m1 factors.push_back(self.linearizedFactorGraph[2]); // involves m2 - // TODO(Varun) remove this block since sum is no longer exposed. - // // Check that sum works: - // auto sum = factors.sum(); - // Assignment mode; - // mode[M(1)] = 0; - // mode[M(2)] = 1; - // auto actual = sum(mode); // Selects one of 4 mode - // combinations. EXPECT_LONGS_EQUAL(2, actual.size()); // 2 motion models. - // Eliminate x2 Ordering ordering; ordering += X(2); @@ -364,51 +373,10 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { CHECK(discreteFactor); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); EXPECT(discreteFactor->root_->isLeaf() == false); + + // TODO(Varun) Test emplace_discrete } -// /* -// ****************************************************************************/ -// /// Test the toDecisionTreeFactor method -// TEST(HybridFactorGraph, ToDecisionTreeFactor) { -// size_t K = 3; - -// // Provide tight sigma values so that the errors are visibly different. -// double between_sigma = 5e-8, prior_sigma = 1e-7; - -// Switching self(K, between_sigma, prior_sigma); - -// // Clear out discrete factors since sum() cannot hanldle those -// HybridGaussianFactorGraph linearizedFactorGraph( -// self.linearizedFactorGraph.gaussianGraph(), DiscreteFactorGraph(), -// self.linearizedFactorGraph.dcGraph()); - -// auto decisionTreeFactor = linearizedFactorGraph.toDecisionTreeFactor(); - -// auto allAssignments = -// DiscreteValues::CartesianProduct(linearizedFactorGraph.discreteKeys()); - -// // Get the error of the discrete assignment m1=0, m2=1. -// double actual = (*decisionTreeFactor)(allAssignments[1]); - -// /********************************************/ -// // Create equivalent factor graph for m1=0, m2=1 -// GaussianFactorGraph graph = linearizedFactorGraph.gaussianGraph(); - -// for (auto &p : linearizedFactorGraph.dcGraph()) { -// if (auto mixture = -// boost::dynamic_pointer_cast(p)) { -// graph.add((*mixture)(allAssignments[1])); -// } -// } - -// VectorValues values = graph.optimize(); -// double expected = graph.probPrime(values); -// /********************************************/ -// EXPECT_DOUBLES_EQUAL(expected, actual, 1e-12); -// // REGRESSION: -// EXPECT_DOUBLES_EQUAL(0.6125, actual, 1e-4); -// } - /**************************************************************************** * Test partial elimination */ @@ -428,7 +396,6 @@ TEST(HybridFactorGraph, Partial_Elimination) { linearizedFactorGraph.eliminatePartialSequential(ordering); CHECK(hybridBayesNet); - // GTSAM_PRINT(*hybridBayesNet); // HybridBayesNet EXPECT_LONGS_EQUAL(3, hybridBayesNet->size()); EXPECT(hybridBayesNet->at(0)->frontals() == KeyVector{X(1)}); EXPECT(hybridBayesNet->at(0)->parents() == KeyVector({X(2), M(1)})); @@ -438,7 +405,6 @@ TEST(HybridFactorGraph, Partial_Elimination) { EXPECT(hybridBayesNet->at(2)->parents() == KeyVector({M(1), M(2)})); CHECK(remainingFactorGraph); - // GTSAM_PRINT(*remainingFactorGraph); // HybridFactorGraph EXPECT_LONGS_EQUAL(3, remainingFactorGraph->size()); EXPECT(remainingFactorGraph->at(0)->keys() == KeyVector({M(1)})); EXPECT(remainingFactorGraph->at(1)->keys() == KeyVector({M(2), M(1)})); @@ -721,13 +687,8 @@ TEST(HybridFactorGraph, DefaultDecisionTree) { moving = boost::make_shared(X(0), X(1), odometry, noise_model); std::vector motion_models = {still, moving}; - // TODO(Varun) Make a templated constructor for MixtureFactor which does this? - std::vector components; - for (auto&& f : motion_models) { - components.push_back(boost::dynamic_pointer_cast(f)); - } fg.emplace_hybrid( - contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components); + contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, motion_models); // Add Range-Bearing measurements to from X0 to L0 and X1 to L1. // create a noise model for the landmark measurements @@ -757,7 +718,7 @@ TEST(HybridFactorGraph, DefaultDecisionTree) { ordering += X(0); ordering += X(1); - HybridGaussianFactorGraph linearized = fg.linearize(initialEstimate); + HybridGaussianFactorGraph linearized = *fg.linearize(initialEstimate); gtsam::HybridBayesNet::shared_ptr hybridBayesNet; gtsam::HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp new file mode 100644 index 000000000..fbb114ef3 --- /dev/null +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -0,0 +1,586 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridNonlinearISAM.cpp + * @brief Unit tests for nonlinear incremental inference + * @author Varun Agrawal, Fan Jiang, Frank Dellaert + * @date Jan 2021 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "Switching.h" + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::L; +using symbol_shorthand::M; +using symbol_shorthand::W; +using symbol_shorthand::X; +using symbol_shorthand::Y; +using symbol_shorthand::Z; + +/* ****************************************************************************/ +// Test if we can perform elimination incrementally. +TEST(HybridNonlinearISAM, IncrementalElimination) { + Switching switching(3); + HybridNonlinearISAM isam; + HybridNonlinearFactorGraph graph1; + Values initial; + + // Create initial factor graph + // * * * + // | | | + // X1 -*- X2 -*- X3 + // \*-M1-*/ + graph1.push_back(switching.nonlinearFactorGraph.at(0)); // P(X1) + graph1.push_back(switching.nonlinearFactorGraph.at(1)); // P(X1, X2 | M1) + graph1.push_back(switching.nonlinearFactorGraph.at(2)); // P(X2, X3 | M2) + graph1.push_back(switching.nonlinearFactorGraph.at(5)); // P(M1) + + initial.insert(X(1), 1); + initial.insert(X(2), 2); + initial.insert(X(3), 3); + + // Run update step + isam.update(graph1, initial); + + // Check that after update we have 3 hybrid Bayes net nodes: + // P(X1 | X2, M1) and P(X2, X3 | M1, M2), P(M1, M2) + HybridGaussianISAM bayesTree = isam.bayesTree(); + EXPECT_LONGS_EQUAL(3, bayesTree.size()); + EXPECT(bayesTree[X(1)]->conditional()->frontals() == KeyVector{X(1)}); + EXPECT(bayesTree[X(1)]->conditional()->parents() == KeyVector({X(2), M(1)})); + EXPECT(bayesTree[X(2)]->conditional()->frontals() == KeyVector({X(2), X(3)})); + EXPECT(bayesTree[X(2)]->conditional()->parents() == KeyVector({M(1), M(2)})); + + /********************************************************/ + // New factor graph for incremental update. + HybridNonlinearFactorGraph graph2; + initial = Values(); + + graph1.push_back(switching.nonlinearFactorGraph.at(3)); // P(X2) + graph2.push_back(switching.nonlinearFactorGraph.at(4)); // P(X3) + graph2.push_back(switching.nonlinearFactorGraph.at(6)); // P(M1, M2) + + isam.update(graph2, initial); + + bayesTree = isam.bayesTree(); + // Check that after the second update we have + // 1 additional hybrid Bayes net node: + // P(X2, X3 | M1, M2) + EXPECT_LONGS_EQUAL(3, bayesTree.size()); + EXPECT(bayesTree[X(3)]->conditional()->frontals() == KeyVector({X(2), X(3)})); + EXPECT(bayesTree[X(3)]->conditional()->parents() == KeyVector({M(1), M(2)})); +} + +/* ****************************************************************************/ +// Test if we can incrementally do the inference +TEST(HybridNonlinearISAM, IncrementalInference) { + Switching switching(3); + HybridNonlinearISAM isam; + HybridNonlinearFactorGraph graph1; + Values initial; + + // Create initial factor graph + // * * * + // | | | + // X1 -*- X2 -*- X3 + // | | + // *-M1 - * - M2 + graph1.push_back(switching.nonlinearFactorGraph.at(0)); // P(X1) + graph1.push_back(switching.nonlinearFactorGraph.at(1)); // P(X1, X2 | M1) + graph1.push_back(switching.nonlinearFactorGraph.at(3)); // P(X2) + graph1.push_back(switching.nonlinearFactorGraph.at(5)); // P(M1) + + initial.insert(X(1), 1); + initial.insert(X(2), 2); + + // Run update step + isam.update(graph1, initial); + HybridGaussianISAM bayesTree = isam.bayesTree(); + + auto discreteConditional_m1 = + bayesTree[M(1)]->conditional()->asDiscreteConditional(); + EXPECT(discreteConditional_m1->keys() == KeyVector({M(1)})); + + /********************************************************/ + // New factor graph for incremental update. + HybridNonlinearFactorGraph graph2; + initial = Values(); + + initial.insert(X(3), 3); + + graph2.push_back(switching.nonlinearFactorGraph.at(2)); // P(X2, X3 | M2) + graph2.push_back(switching.nonlinearFactorGraph.at(4)); // P(X3) + graph2.push_back(switching.nonlinearFactorGraph.at(6)); // P(M1, M2) + + isam.update(graph2, initial); + bayesTree = isam.bayesTree(); + + /********************************************************/ + // Run batch elimination so we can compare results. + Ordering ordering; + ordering += X(1); + ordering += X(2); + ordering += X(3); + + // Now we calculate the actual factors using full elimination + HybridBayesTree::shared_ptr expectedHybridBayesTree; + HybridGaussianFactorGraph::shared_ptr expectedRemainingGraph; + std::tie(expectedHybridBayesTree, expectedRemainingGraph) = + switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering); + + // The densities on X(1) should be the same + auto x1_conditional = dynamic_pointer_cast( + bayesTree[X(1)]->conditional()->inner()); + auto actual_x1_conditional = dynamic_pointer_cast( + (*expectedHybridBayesTree)[X(1)]->conditional()->inner()); + EXPECT(assert_equal(*x1_conditional, *actual_x1_conditional)); + + // The densities on X(2) should be the same + auto x2_conditional = dynamic_pointer_cast( + bayesTree[X(2)]->conditional()->inner()); + auto actual_x2_conditional = dynamic_pointer_cast( + (*expectedHybridBayesTree)[X(2)]->conditional()->inner()); + EXPECT(assert_equal(*x2_conditional, *actual_x2_conditional)); + + // The densities on X(3) should be the same + auto x3_conditional = dynamic_pointer_cast( + bayesTree[X(3)]->conditional()->inner()); + auto actual_x3_conditional = dynamic_pointer_cast( + (*expectedHybridBayesTree)[X(2)]->conditional()->inner()); + EXPECT(assert_equal(*x3_conditional, *actual_x3_conditional)); + + // We only perform manual continuous elimination for 0,0. + // The other discrete probabilities on M(2) are calculated the same way + Ordering discrete_ordering; + discrete_ordering += M(1); + discrete_ordering += M(2); + HybridBayesTree::shared_ptr discreteBayesTree = + expectedRemainingGraph->eliminateMultifrontal(discrete_ordering); + + DiscreteValues m00; + m00[M(1)] = 0, m00[M(2)] = 0; + DiscreteConditional decisionTree = + *(*discreteBayesTree)[M(2)]->conditional()->asDiscreteConditional(); + double m00_prob = decisionTree(m00); + + auto discreteConditional = + bayesTree[M(2)]->conditional()->asDiscreteConditional(); + + // Test if the probability values are as expected with regression tests. + DiscreteValues assignment; + EXPECT(assert_equal(m00_prob, 0.0619233, 1e-5)); + assignment[M(1)] = 0; + assignment[M(2)] = 0; + EXPECT(assert_equal(m00_prob, (*discreteConditional)(assignment), 1e-5)); + assignment[M(1)] = 1; + assignment[M(2)] = 0; + EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5)); + assignment[M(1)] = 0; + assignment[M(2)] = 1; + EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5)); + assignment[M(1)] = 1; + assignment[M(2)] = 1; + EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5)); + + // Check if the clique conditional generated from incremental elimination + // matches that of batch elimination. + auto expectedChordal = expectedRemainingGraph->eliminateMultifrontal(); + auto expectedConditional = dynamic_pointer_cast( + (*expectedChordal)[M(2)]->conditional()->inner()); + auto actualConditional = dynamic_pointer_cast( + bayesTree[M(2)]->conditional()->inner()); + EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6)); +} + +/* ****************************************************************************/ +// Test if we can approximately do the inference +TEST(HybridNonlinearISAM, Approx_inference) { + Switching switching(4); + HybridNonlinearISAM incrementalHybrid; + HybridNonlinearFactorGraph graph1; + Values initial; + + // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4 + for (size_t i = 1; i < 4; i++) { + graph1.push_back(switching.nonlinearFactorGraph.at(i)); + } + + // Add the Gaussian factors, 1 prior on X(1), + // 3 measurements on X(2), X(3), X(4) + graph1.push_back(switching.nonlinearFactorGraph.at(0)); + for (size_t i = 4; i <= 7; i++) { + graph1.push_back(switching.nonlinearFactorGraph.at(i)); + initial.insert(X(i - 3), i - 3); + } + + // Create ordering. + Ordering ordering; + for (size_t j = 1; j <= 4; j++) { + ordering += X(j); + } + + // Now we calculate the actual factors using full elimination + HybridBayesTree::shared_ptr unprunedHybridBayesTree; + HybridGaussianFactorGraph::shared_ptr unprunedRemainingGraph; + std::tie(unprunedHybridBayesTree, unprunedRemainingGraph) = + switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering); + + size_t maxNrLeaves = 5; + incrementalHybrid.update(graph1, initial); + HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree(); + + bayesTree.prune(maxNrLeaves); + + /* + unpruned factor is: + Choice(m3) + 0 Choice(m2) + 0 0 Choice(m1) + 0 0 0 Leaf 0.11267528 + 0 0 1 Leaf 0.18576102 + 0 1 Choice(m1) + 0 1 0 Leaf 0.18754662 + 0 1 1 Leaf 0.30623871 + 1 Choice(m2) + 1 0 Choice(m1) + 1 0 0 Leaf 0.18576102 + 1 0 1 Leaf 0.30622428 + 1 1 Choice(m1) + 1 1 0 Leaf 0.30623871 + 1 1 1 Leaf 0.5 + + pruned factors is: + Choice(m3) + 0 Choice(m2) + 0 0 Leaf 0 + 0 1 Choice(m1) + 0 1 0 Leaf 0.18754662 + 0 1 1 Leaf 0.30623871 + 1 Choice(m2) + 1 0 Choice(m1) + 1 0 0 Leaf 0 + 1 0 1 Leaf 0.30622428 + 1 1 Choice(m1) + 1 1 0 Leaf 0.30623871 + 1 1 1 Leaf 0.5 + */ + + auto discreteConditional_m1 = *dynamic_pointer_cast( + bayesTree[M(1)]->conditional()->inner()); + EXPECT(discreteConditional_m1.keys() == KeyVector({M(1), M(2), M(3)})); + + // Get the number of elements which are greater than 0. + auto count = [](const double &value, int count) { + return value > 0 ? count + 1 : count; + }; + // Check that the number of leaves after pruning is 5. + EXPECT_LONGS_EQUAL(5, discreteConditional_m1.fold(count, 0)); + + // Check that the hybrid nodes of the bayes net match those of the pre-pruning + // bayes net, at the same positions. + auto &unprunedLastDensity = *dynamic_pointer_cast( + unprunedHybridBayesTree->clique(X(4))->conditional()->inner()); + auto &lastDensity = *dynamic_pointer_cast( + bayesTree[X(4)]->conditional()->inner()); + + std::vector> assignments = + discreteConditional_m1.enumerate(); + // Loop over all assignments and check the pruned components + for (auto &&av : assignments) { + const DiscreteValues &assignment = av.first; + const double value = av.second; + + if (value == 0.0) { + EXPECT(lastDensity(assignment) == nullptr); + } else { + CHECK(lastDensity(assignment)); + EXPECT(assert_equal(*unprunedLastDensity(assignment), + *lastDensity(assignment))); + } + } +} + +/* ****************************************************************************/ +// Test approximate inference with an additional pruning step. +TEST(HybridNonlinearISAM, Incremental_approximate) { + Switching switching(5); + HybridNonlinearISAM incrementalHybrid; + HybridNonlinearFactorGraph graph1; + Values initial; + + /***** Run Round 1 *****/ + // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4 + for (size_t i = 1; i < 4; i++) { + graph1.push_back(switching.nonlinearFactorGraph.at(i)); + } + + // Add the Gaussian factors, 1 prior on X(1), + // 3 measurements on X(2), X(3), X(4) + graph1.push_back(switching.nonlinearFactorGraph.at(0)); + initial.insert(X(1), 1); + for (size_t i = 5; i <= 7; i++) { + graph1.push_back(switching.nonlinearFactorGraph.at(i)); + initial.insert(X(i - 3), i - 3); + } + + // Run update with pruning + size_t maxComponents = 5; + incrementalHybrid.update(graph1, initial); + HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree(); + + bayesTree.prune(maxComponents); + + // Check if we have a bayes tree with 4 hybrid nodes, + // each with 2, 4, 8, and 5 (pruned) leaves respetively. + EXPECT_LONGS_EQUAL(4, bayesTree.size()); + EXPECT_LONGS_EQUAL( + 2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents()); + EXPECT_LONGS_EQUAL( + 3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents()); + EXPECT_LONGS_EQUAL( + 5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents()); + EXPECT_LONGS_EQUAL( + 5, bayesTree[X(4)]->conditional()->asMixture()->nrComponents()); + + /***** Run Round 2 *****/ + HybridGaussianFactorGraph graph2; + graph2.push_back(switching.nonlinearFactorGraph.at(4)); // x4-x5 + graph2.push_back(switching.nonlinearFactorGraph.at(8)); // x5 measurement + initial = Values(); + initial.insert(X(5), 5); + + // Run update with pruning a second time. + incrementalHybrid.update(graph2, initial); + bayesTree = incrementalHybrid.bayesTree(); + + bayesTree.prune(maxComponents); + + // Check if we have a bayes tree with pruned hybrid nodes, + // with 5 (pruned) leaves. + CHECK_EQUAL(5, bayesTree.size()); + EXPECT_LONGS_EQUAL( + 5, bayesTree[X(4)]->conditional()->asMixture()->nrComponents()); + EXPECT_LONGS_EQUAL( + 5, bayesTree[X(5)]->conditional()->asMixture()->nrComponents()); +} + +/* ************************************************************************/ +// A GTSAM-only test for running inference on a single-legged robot. +// The leg links are represented by the chain X-Y-Z-W, where X is the base and +// W is the foot. +// We use BetweenFactor as constraints between each of the poses. +TEST(HybridNonlinearISAM, NonTrivial) { + /*************** Run Round 1 ***************/ + HybridNonlinearFactorGraph fg; + HybridNonlinearISAM inc; + + // Add a prior on pose x1 at the origin. + // A prior factor consists of a mean and + // a noise model (covariance matrix) + Pose2 prior(0.0, 0.0, 0.0); // prior mean is at origin + auto priorNoise = noiseModel::Diagonal::Sigmas( + Vector3(0.3, 0.3, 0.1)); // 30cm std on x,y, 0.1 rad on theta + fg.emplace_nonlinear>(X(0), prior, priorNoise); + + // create a noise model for the landmark measurements + auto poseNoise = noiseModel::Isotropic::Sigma(3, 0.1); + + // We model a robot's single leg as X - Y - Z - W + // where X is the base link and W is the foot link. + + // Add connecting poses similar to PoseFactors in GTD + fg.emplace_nonlinear>(X(0), Y(0), Pose2(0, 1.0, 0), + poseNoise); + fg.emplace_nonlinear>(Y(0), Z(0), Pose2(0, 1.0, 0), + poseNoise); + fg.emplace_nonlinear>(Z(0), W(0), Pose2(0, 1.0, 0), + poseNoise); + + // Create initial estimate + Values initial; + initial.insert(X(0), Pose2(0.0, 0.0, 0.0)); + initial.insert(Y(0), Pose2(0.0, 1.0, 0.0)); + initial.insert(Z(0), Pose2(0.0, 2.0, 0.0)); + initial.insert(W(0), Pose2(0.0, 3.0, 0.0)); + + // Don't run update now since we don't have discrete variables involved. + + using PlanarMotionModel = BetweenFactor; + + /*************** Run Round 2 ***************/ + // Add odometry factor with discrete modes. + Pose2 odometry(1.0, 0.0, 0.0); + KeyVector contKeys = {W(0), W(1)}; + auto noise_model = noiseModel::Isotropic::Sigma(3, 1.0); + auto still = boost::make_shared(W(0), W(1), Pose2(0, 0, 0), + noise_model), + moving = boost::make_shared(W(0), W(1), odometry, + noise_model); + std::vector components = {moving, still}; + auto mixtureFactor = boost::make_shared( + contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components); + fg.push_back(mixtureFactor); + + // Add equivalent of ImuFactor + fg.emplace_nonlinear>(X(0), X(1), Pose2(1.0, 0.0, 0), + poseNoise); + // PoseFactors-like at k=1 + fg.emplace_nonlinear>(X(1), Y(1), Pose2(0, 1, 0), + poseNoise); + fg.emplace_nonlinear>(Y(1), Z(1), Pose2(0, 1, 0), + poseNoise); + fg.emplace_nonlinear>(Z(1), W(1), Pose2(-1, 1, 0), + poseNoise); + + initial.insert(X(1), Pose2(1.0, 0.0, 0.0)); + initial.insert(Y(1), Pose2(1.0, 1.0, 0.0)); + initial.insert(Z(1), Pose2(1.0, 2.0, 0.0)); + // The leg link did not move so we set the expected pose accordingly. + initial.insert(W(1), Pose2(0.0, 3.0, 0.0)); + + // Update without pruning + // The result is a HybridBayesNet with 1 discrete variable M(1). + // P(X | measurements) = P(W0|Z0, W1, M1) P(Z0|Y0, W1, M1) P(Y0|X0, W1, M1) + // P(X0 | X1, W1, M1) P(W1|Z1, X1, M1) P(Z1|Y1, X1, M1) + // P(Y1 | X1, M1)P(X1 | M1)P(M1) + // The MHS tree is a 1 level tree for time indices (1,) with 2 leaves. + inc.update(fg, initial); + + fg = HybridNonlinearFactorGraph(); + initial = Values(); + + /*************** Run Round 3 ***************/ + // Add odometry factor with discrete modes. + contKeys = {W(1), W(2)}; + still = boost::make_shared(W(1), W(2), Pose2(0, 0, 0), + noise_model); + moving = + boost::make_shared(W(1), W(2), odometry, noise_model); + components = {moving, still}; + mixtureFactor = boost::make_shared( + contKeys, DiscreteKeys{gtsam::DiscreteKey(M(2), 2)}, components); + fg.push_back(mixtureFactor); + + // Add equivalent of ImuFactor + fg.emplace_nonlinear>(X(1), X(2), Pose2(1.0, 0.0, 0), + poseNoise); + // PoseFactors-like at k=1 + fg.emplace_nonlinear>(X(2), Y(2), Pose2(0, 1, 0), + poseNoise); + fg.emplace_nonlinear>(Y(2), Z(2), Pose2(0, 1, 0), + poseNoise); + fg.emplace_nonlinear>(Z(2), W(2), Pose2(-2, 1, 0), + poseNoise); + + initial.insert(X(2), Pose2(2.0, 0.0, 0.0)); + initial.insert(Y(2), Pose2(2.0, 1.0, 0.0)); + initial.insert(Z(2), Pose2(2.0, 2.0, 0.0)); + initial.insert(W(2), Pose2(0.0, 3.0, 0.0)); + + // Now we prune! + // P(X | measurements) = P(W0|Z0, W1, M1) P(Z0|Y0, W1, M1) P(Y0|X0, W1, M1) + // P(X0 | X1, W1, M1) P(W1|W2, Z1, X1, M1, M2) + // P(Z1| W2, Y1, X1, M1, M2) P(Y1 | W2, X1, M1, M2) + // P(X1 | W2, X2, M1, M2) P(W2|Z2, X2, M1, M2) + // P(Z2|Y2, X2, M1, M2) P(Y2 | X2, M1, M2) + // P(X2 | M1, M2) P(M1, M2) + // The MHS at this point should be a 2 level tree on (1, 2). + // 1 has 2 choices, and 2 has 4 choices. + inc.update(fg, initial); + inc.prune(2); + + fg = HybridNonlinearFactorGraph(); + initial = Values(); + + /*************** Run Round 4 ***************/ + // Add odometry factor with discrete modes. + contKeys = {W(2), W(3)}; + still = boost::make_shared(W(2), W(3), Pose2(0, 0, 0), + noise_model); + moving = + boost::make_shared(W(2), W(3), odometry, noise_model); + components = {moving, still}; + mixtureFactor = boost::make_shared( + contKeys, DiscreteKeys{gtsam::DiscreteKey(M(3), 2)}, components); + fg.push_back(mixtureFactor); + + // Add equivalent of ImuFactor + fg.emplace_nonlinear>(X(2), X(3), Pose2(1.0, 0.0, 0), + poseNoise); + // PoseFactors-like at k=3 + fg.emplace_nonlinear>(X(3), Y(3), Pose2(0, 1, 0), + poseNoise); + fg.emplace_nonlinear>(Y(3), Z(3), Pose2(0, 1, 0), + poseNoise); + fg.emplace_nonlinear>(Z(3), W(3), Pose2(-3, 1, 0), + poseNoise); + + initial.insert(X(3), Pose2(3.0, 0.0, 0.0)); + initial.insert(Y(3), Pose2(3.0, 1.0, 0.0)); + initial.insert(Z(3), Pose2(3.0, 2.0, 0.0)); + initial.insert(W(3), Pose2(0.0, 3.0, 0.0)); + + // Keep pruning! + inc.update(fg, initial); + inc.prune(3); + + fg = HybridNonlinearFactorGraph(); + initial = Values(); + + HybridGaussianISAM bayesTree = inc.bayesTree(); + + // The final discrete graph should not be empty since we have eliminated + // all continuous variables. + auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional(); + EXPECT_LONGS_EQUAL(3, discreteTree->size()); + + // Test if the optimal discrete mode assignment is (1, 1, 1). + DiscreteFactorGraph discreteGraph; + discreteGraph.push_back(discreteTree); + DiscreteValues optimal_assignment = discreteGraph.optimize(); + + DiscreteValues expected_assignment; + expected_assignment[M(1)] = 1; + expected_assignment[M(2)] = 1; + expected_assignment[M(3)] = 1; + + EXPECT(assert_equal(expected_assignment, optimal_assignment)); + + // Test if pruning worked correctly by checking that + // we only have 3 leaves in the last node. + auto lastConditional = bayesTree[X(3)]->conditional()->asMixture(); + EXPECT_LONGS_EQUAL(3, lastConditional->nrComponents()); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 3d0c4ecff..b9bffe13f 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -33,7 +33,6 @@ namespace gtsam { // Forward declarations template class FactorGraph; template class EliminatableClusterTree; - class HybridBayesTreeClique; /* ************************************************************************* */ /** clique statistics */ diff --git a/gtsam/inference/JunctionTree-inst.h b/gtsam/inference/JunctionTree-inst.h index 04472f7e3..1c68aa0da 100644 --- a/gtsam/inference/JunctionTree-inst.h +++ b/gtsam/inference/JunctionTree-inst.h @@ -33,7 +33,7 @@ struct ConstructorTraversalData { typedef typename JunctionTree::sharedNode sharedNode; ConstructorTraversalData* const parentData; - sharedNode myJTNode; + sharedNode junctionTreeNode; FastVector childSymbolicConditionals; FastVector childSymbolicFactors; @@ -53,8 +53,9 @@ struct ConstructorTraversalData { // a traversal data structure with its own JT node, and create a child // pointer in its parent. ConstructorTraversalData myData = ConstructorTraversalData(&parentData); - myData.myJTNode = boost::make_shared(node->key, node->factors); - parentData.myJTNode->addChild(myData.myJTNode); + myData.junctionTreeNode = + boost::make_shared(node->key, node->factors); + parentData.junctionTreeNode->addChild(myData.junctionTreeNode); return myData; } @@ -91,7 +92,7 @@ struct ConstructorTraversalData { myData.parentData->childSymbolicConditionals.push_back(myConditional); myData.parentData->childSymbolicFactors.push_back(mySeparatorFactor); - sharedNode node = myData.myJTNode; + sharedNode node = myData.junctionTreeNode; const FastVector& childConditionals = myData.childSymbolicConditionals; node->problemSize_ = (int) (myConditional->size() * symbolicFactors.size()); @@ -138,14 +139,14 @@ JunctionTree::JunctionTree( typedef typename EliminationTree::Node ETreeNode; typedef ConstructorTraversalData Data; Data rootData(0); - rootData.myJTNode = boost::make_shared(); // Make a dummy node to gather - // the junction tree roots + // Make a dummy node to gather the junction tree roots + rootData.junctionTreeNode = boost::make_shared(); treeTraversal::DepthFirstForest(eliminationTree, rootData, Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPostAlg2); // Assign roots from the dummy node - this->addChildrenAsRoots(rootData.myJTNode); + this->addChildrenAsRoots(rootData.junctionTreeNode); // Transfer remaining factors from elimination tree Base::remainingFactors_ = eliminationTree.remainingFactors(); diff --git a/gtsam/linear/tests/testSerializationLinear.cpp b/gtsam/linear/tests/testSerializationLinear.cpp index 881b2830e..ee21de364 100644 --- a/gtsam/linear/tests/testSerializationLinear.cpp +++ b/gtsam/linear/tests/testSerializationLinear.cpp @@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) { EXPECT(equalsBinary(graph)); } +/* ****************************************************************************/ +TEST(Serialization, gaussian_bayes_net) { + // Create an arbitrary Bayes Net + GaussianBayesNet gbn; + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3, + (Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4, + (Matrix2() << 11.0, 12.0, 13.0, 14.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(), + 2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4, + (Matrix2() << 25.0, 26.0, 27.0, 28.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(), + 3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(), + 4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished())); + + std::string serialized = serialize(gbn); + GaussianBayesNet actual; + deserialize(serialized, actual); + EXPECT(assert_equal(gbn, actual)); +} + /* ************************************************************************* */ TEST (Serialization, gaussian_bayes_tree) { const Key x1=1, x2=2, x3=3, x4=4; diff --git a/gtsam/sfm/DsfTrackGenerator.cpp b/gtsam/sfm/DsfTrackGenerator.cpp new file mode 100644 index 000000000..e82880193 --- /dev/null +++ b/gtsam/sfm/DsfTrackGenerator.cpp @@ -0,0 +1,136 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DsfTrackGenerator.cpp + * @date October 2022 + * @author John Lambert + * @brief Identifies connected components in the keypoint matches graph. + */ + +#include + +#include +#include + +namespace gtsam { + +namespace gtsfm { + +typedef DSFMap DSFMapIndexPair; + +/// Generate the DSF to form tracks. +static DSFMapIndexPair generateDSF(const MatchIndicesMap& matches) { + DSFMapIndexPair dsf; + + for (const auto& kv : matches) { + const auto pair_indices = kv.first; + const auto corr_indices = kv.second; + + // Image pair is (i1,i2). + size_t i1 = pair_indices.first; + size_t i2 = pair_indices.second; + for (size_t k = 0; k < corr_indices.rows(); k++) { + // Measurement indices are found in a single matrix row, as (k1,k2). + size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1); + // Unique key for DSF is (i,k), representing keypoint index in an image. + dsf.merge({i1, k1}, {i2, k2}); + } + } + + return dsf; +} + +/// Generate a single track from a set of index pairs +static SfmTrack2d trackFromIndexPairs(const std::set& index_pair_set, + const KeypointsVector& keypoints) { + // Initialize track from measurements. + SfmTrack2d track2d; + + for (const auto& index_pair : index_pair_set) { + // Camera index is represented by i, and measurement index is + // represented by k. + size_t i = index_pair.i(); + size_t k = index_pair.j(); + // Add measurement to this track. + track2d.addMeasurement(i, keypoints[i].coordinates.row(k)); + } + + return track2d; +} + +/// Generate tracks from the DSF. +static std::vector tracksFromDSF(const DSFMapIndexPair& dsf, + const KeypointsVector& keypoints) { + const std::map > key_sets = dsf.sets(); + + // Return immediately if no sets were found. + if (key_sets.empty()) return {}; + + // Create a list of tracks. + // Each track will be represented as a list of (camera_idx, measurements). + std::vector tracks2d; + for (const auto& kv : key_sets) { + // Initialize track from measurements. + SfmTrack2d track2d = trackFromIndexPairs(kv.second, keypoints); + tracks2d.emplace_back(track2d); + } + return tracks2d; +} + +/** + * @brief Creates a list of tracks from 2d point correspondences. + * + * Creates a disjoint-set forest (DSF) and 2d tracks from pairwise matches. + * We create a singleton for union-find set elements from camera index of a + * detection and the index of that detection in that camera's keypoint list, + * i.e. (i,k). + * + * @param Map from (i1,i2) image pair indices to (K,2) matrix, for K + * correspondence indices, from each image. + * @param Length-N list of keypoints, for N images/cameras. + */ +std::vector tracksFromPairwiseMatches( + const MatchIndicesMap& matches, const KeypointsVector& keypoints, + bool verbose) { + // Generate the DSF to form tracks. + if (verbose) std::cout << "[SfmTrack2d] Starting Union-Find..." << std::endl; + DSFMapIndexPair dsf = generateDSF(matches); + if (verbose) std::cout << "[SfmTrack2d] Union-Find Complete" << std::endl; + + std::vector tracks2d = tracksFromDSF(dsf, keypoints); + + // Filter out erroneous tracks that had repeated measurements within the + // same image. This is an expected result from an incorrect correspondence + // slipping through. + std::vector validTracks; + std::copy_if( + tracks2d.begin(), tracks2d.end(), std::back_inserter(validTracks), + [](const SfmTrack2d& track2d) { return track2d.hasUniqueCameras(); }); + + if (verbose) { + size_t erroneous_track_count = tracks2d.size() - validTracks.size(); + double erroneous_percentage = static_cast(erroneous_track_count) / + static_cast(tracks2d.size()) * 100; + + std::cout << std::fixed << std::setprecision(2); + std::cout << "DSF Union-Find: " << erroneous_percentage; + std::cout << "% of tracks discarded from multiple obs. in a single image." + << std::endl; + } + + // TODO(johnwlambert): return the Transitivity failure percentage here. + return tracks2d; +} + +} // namespace gtsfm + +} // namespace gtsam diff --git a/gtsam/sfm/DsfTrackGenerator.h b/gtsam/sfm/DsfTrackGenerator.h new file mode 100644 index 000000000..14ec2302d --- /dev/null +++ b/gtsam/sfm/DsfTrackGenerator.h @@ -0,0 +1,78 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DsfTrackGenerator.h + * @date July 2022 + * @author John Lambert + * @brief Identifies connected components in the keypoint matches graph. + */ + +#pragma once +#include +#include + +#include +#include +#include +#include + +namespace gtsam { + +namespace gtsfm { + +typedef Eigen::MatrixX2i CorrespondenceIndices; // N x 2 array + +// Output of detections in an image. +// Coordinate system convention: +// 1. The x coordinate denotes the horizontal direction (+ve direction towards +// the right). +// 2. The y coordinate denotes the vertical direction (+ve direction downwards). +// 3. Origin is at the top left corner of the image. +struct Keypoints { + // The (x, y) coordinates of the features, of shape Nx2. + Eigen::MatrixX2d coordinates; + + // Optional scale of the detections, of shape N. + // Note: gtsam::Vector is typedef'd for Eigen::VectorXd. + boost::optional scales; + + /// Optional confidences/responses for each detection, of shape N. + boost::optional responses; + + Keypoints(const Eigen::MatrixX2d& coordinates) + : coordinates(coordinates){}; // boost::none +}; + +using KeypointsVector = std::vector; +// Mapping from each image pair to (N,2) array representing indices of matching +// keypoints. +using MatchIndicesMap = std::map; + +/** + * @brief Creates a list of tracks from 2d point correspondences. + * + * Creates a disjoint-set forest (DSF) and 2d tracks from pairwise matches. + * We create a singleton for union-find set elements from camera index of a + * detection and the index of that detection in that camera's keypoint list, + * i.e. (i,k). + * + * @param Map from (i1,i2) image pair indices to (K,2) matrix, for K + * correspondence indices, from each image. + * @param Length-N list of keypoints, for N images/cameras. + */ +std::vector tracksFromPairwiseMatches( + const MatchIndicesMap& matches, const KeypointsVector& keypoints, + bool verbose = false); + +} // namespace gtsfm + +} // namespace gtsam diff --git a/gtsam/sfm/SfmTrack.h b/gtsam/sfm/SfmTrack.h index ab084aca9..c75199374 100644 --- a/gtsam/sfm/SfmTrack.h +++ b/gtsam/sfm/SfmTrack.h @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -35,28 +36,26 @@ typedef std::pair SfmMeasurement; typedef std::pair SiftIndex; /** - * @brief An SfmTrack stores SfM measurements grouped in a track - * @ingroup sfm + * @brief Track containing 2D measurements associated with a single 3D point. + * Note: Equivalent to gtsam.SfmTrack, but without the 3d measurement. + * This class holds data temporarily before 3D point is initialized. */ -struct GTSAM_EXPORT SfmTrack { - Point3 p; ///< 3D position of the point - float r, g, b; ///< RGB color of the 3D point - +struct GTSAM_EXPORT SfmTrack2d { /// The 2D image projections (id,(u,v)) std::vector measurements; - /// The feature descriptors + /// The feature descriptors (optional) std::vector siftIndices; /// @name Constructors /// @{ - explicit SfmTrack(float r = 0, float g = 0, float b = 0) - : p(0, 0, 0), r(r), g(g), b(b) {} + // Default constructor. + SfmTrack2d() = default; - explicit SfmTrack(const gtsam::Point3& pt, float r = 0, float g = 0, - float b = 0) - : p(pt), r(r), g(g), b(b) {} + // Constructor from measurements. + explicit SfmTrack2d(const std::vector& measurements) + : measurements(measurements) {} /// @} /// @name Standard Interface @@ -78,6 +77,70 @@ struct GTSAM_EXPORT SfmTrack { /// Get the SIFT feature index corresponding to the measurement at `idx` const SiftIndex& siftIndex(size_t idx) const { return siftIndices[idx]; } + /** + * @brief Check that no two measurements are from the same camera. + * @returns boolean result of the validation. + */ + bool hasUniqueCameras() const { + std::vector track_cam_indices; + for (auto& measurement : measurements) { + track_cam_indices.emplace_back(measurement.first); + } + auto i = + std::adjacent_find(track_cam_indices.begin(), track_cam_indices.end()); + bool all_cameras_unique = (i == track_cam_indices.end()); + return all_cameras_unique; + } + + /// @} + /// @name Vectorized Interface + /// @{ + + /// @brief Return the measurements as a 2D matrix + Eigen::MatrixX2d measurementMatrix() const { + Eigen::MatrixX2d m(numberMeasurements(), 2); + for (size_t i = 0; i < numberMeasurements(); i++) { + m.row(i) = measurement(i).second; + } + return m; + } + + /// @brief Return the camera indices of the measurements + Eigen::VectorXi indexVector() const { + Eigen::VectorXi v(numberMeasurements()); + for (size_t i = 0; i < numberMeasurements(); i++) { + v(i) = measurement(i).first; + } + return v; + } + + /// @} +}; + +using SfmTrack2dVector = std::vector; + +/** + * @brief An SfmTrack stores SfM measurements grouped in a track + * @addtogroup sfm + */ +struct GTSAM_EXPORT SfmTrack : SfmTrack2d { + Point3 p; ///< 3D position of the point + float r, g, b; ///< RGB color of the 3D point + + /// @name Constructors + /// @{ + + explicit SfmTrack(float r = 0, float g = 0, float b = 0) + : p(0, 0, 0), r(r), g(g), b(b) {} + + explicit SfmTrack(const gtsam::Point3& pt, float r = 0, float g = 0, + float b = 0) + : p(pt), r(r), g(g), b(b) {} + + /// @} + /// @name Standard Interface + /// @{ + /// Get 3D point const Point3& point3() const { return p; } diff --git a/gtsam/sfm/sfm.i b/gtsam/sfm/sfm.i index 83bd07b13..26dc20c3e 100644 --- a/gtsam/sfm/sfm.i +++ b/gtsam/sfm/sfm.i @@ -4,10 +4,23 @@ namespace gtsam { -#include -#include #include -class SfmTrack { +class SfmTrack2d { + std::vector> measurements; + + SfmTrack2d(); + SfmTrack2d(const std::vector& measurements); + size_t numberMeasurements() const; + pair measurement(size_t idx) const; + pair siftIndex(size_t idx) const; + void addMeasurement(size_t idx, const gtsam::Point2& m); + gtsam::SfmMeasurement measurement(size_t idx) const; + bool hasUniqueCameras() const; + Eigen::MatrixX2d measurementMatrix() const; + Eigen::VectorXi indexVector() const; +}; + +virtual class SfmTrack : gtsam::SfmTrack2d { SfmTrack(); SfmTrack(const gtsam::Point3& pt); const Point3& point3() const; @@ -18,13 +31,6 @@ class SfmTrack { double g; double b; - std::vector> measurements; - - size_t numberMeasurements() const; - pair measurement(size_t idx) const; - pair siftIndex(size_t idx) const; - void addMeasurement(size_t idx, const gtsam::Point2& m); - // enabling serialization functionality void serialize() const; @@ -32,6 +38,8 @@ class SfmTrack { bool equals(const gtsam::SfmTrack& expected, double tol) const; }; +#include +#include #include class SfmData { SfmData(); @@ -115,7 +123,7 @@ class BinaryMeasurementsRot3 { #include -// TODO(frank): copy/pasta below until we have integer template paremeters in +// TODO(frank): copy/pasta below until we have integer template parameters in // wrap! class ShonanAveragingParameters2 { @@ -310,4 +318,38 @@ class TranslationRecovery { const gtsam::BinaryMeasurementsUnit3& relativeTranslations) const; }; +namespace gtsfm { + +#include + +class MatchIndicesMap { + MatchIndicesMap(); + MatchIndicesMap(const gtsam::gtsfm::MatchIndicesMap& other); + + size_t size() const; + bool empty() const; + void clear(); + gtsam::gtsfm::CorrespondenceIndices at(const pair& keypair) const; +}; + +class Keypoints { + Keypoints(const Eigen::MatrixX2d& coordinates); + Eigen::MatrixX2d coordinates; +}; + +class KeypointsVector { + KeypointsVector(); + KeypointsVector(const gtsam::gtsfm::KeypointsVector& other); + void push_back(const gtsam::gtsfm::Keypoints& keypoints); + size_t size() const; + bool empty() const; + void clear(); + gtsam::gtsfm::Keypoints at(const size_t& index) const; +}; + +gtsam::SfmTrack2dVector tracksFromPairwiseMatches( + const gtsam::gtsfm::MatchIndicesMap& matches_dict, + const gtsam::gtsfm::KeypointsVector& keypoints_list, bool verbose = false); +} // namespace gtsfm + } // namespace gtsam diff --git a/gtsam/sfm/tests/testSfmTrack.cpp b/gtsam/sfm/tests/testSfmTrack.cpp new file mode 100644 index 000000000..1b8c6bd9a --- /dev/null +++ b/gtsam/sfm/tests/testSfmTrack.cpp @@ -0,0 +1,53 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TestSfmTrack.cpp + * @date October 2022 + * @author Frank Dellaert + * @brief tests for SfmTrack class + */ + +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +TEST(SfmTrack2d, defaultConstructor) { + SfmTrack2d track; + EXPECT_LONGS_EQUAL(0, track.measurements.size()); + EXPECT_LONGS_EQUAL(0, track.siftIndices.size()); +} + +/* ************************************************************************* */ +TEST(SfmTrack2d, measurementConstructor) { + SfmTrack2d track({{0, Point2(1, 2)}, {1, Point2(3, 4)}}); + EXPECT_LONGS_EQUAL(2, track.measurements.size()); + EXPECT_LONGS_EQUAL(0, track.siftIndices.size()); +} + +/* ************************************************************************* */ +TEST(SfmTrack, construction) { + SfmTrack track(Point3(1, 2, 3), 4, 5, 6); + EXPECT(assert_equal(Point3(1, 2, 3), track.point3())); + EXPECT(assert_equal(Point3(4, 5, 6), track.rgb())); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 4457678d2..79a27f17f 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -51,7 +51,10 @@ set(ignore gtsam::BinaryMeasurementsUnit3 gtsam::BinaryMeasurementsRot3 gtsam::DiscreteKey - gtsam::KeyPairDoubleMap) + gtsam::KeyPairDoubleMap + gtsam::gtsfm::MatchIndicesMap + gtsam::gtsfm::KeypointsVector + gtsam::gtsfm::SfmTrack2dVector) set(interface_headers ${PROJECT_SOURCE_DIR}/gtsam/gtsam.i @@ -148,8 +151,12 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON) gtsam::CameraSetCal3Bundler gtsam::CameraSetCal3Unified gtsam::CameraSetCal3Fisheye - gtsam::KeyPairDoubleMap) - + gtsam::KeyPairDoubleMap + gtsam::gtsfm::MatchIndicesMap + gtsam::gtsfm::KeypointsVector + gtsam::gtsfm::SfmTrack2dVector) + + pybind_wrap(${GTSAM_PYTHON_UNSTABLE_TARGET} # target ${PROJECT_SOURCE_DIR}/gtsam_unstable/gtsam_unstable.i # interface_header "gtsam_unstable.cpp" # generated_cpp diff --git a/python/gtsam/gtsfm.py b/python/gtsam/gtsfm.py new file mode 100644 index 000000000..afa709083 --- /dev/null +++ b/python/gtsam/gtsfm.py @@ -0,0 +1,4 @@ +# This trick is to allow direct import of sub-modules +# without this, we can only do `from gtsam.gtsam.gtsfm import X` +# with this trick, we can do `from gtsam.gtsfm import X` +from .gtsam.gtsfm import * \ No newline at end of file diff --git a/python/gtsam/preamble/sfm.h b/python/gtsam/preamble/sfm.h index 8ff0ea82e..27a4e5de9 100644 --- a/python/gtsam/preamble/sfm.h +++ b/python/gtsam/preamble/sfm.h @@ -15,12 +15,13 @@ // #include #include -PYBIND11_MAKE_OPAQUE( - std::vector); - -PYBIND11_MAKE_OPAQUE( - std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MAKE_OPAQUE( std::vector>); PYBIND11_MAKE_OPAQUE( std::vector>); +PYBIND11_MAKE_OPAQUE( + std::vector); +PYBIND11_MAKE_OPAQUE(gtsam::gtsfm::MatchIndicesMap); \ No newline at end of file diff --git a/python/gtsam/specializations/sfm.h b/python/gtsam/specializations/sfm.h index 311b2c59b..c4817f555 100644 --- a/python/gtsam/specializations/sfm.h +++ b/python/gtsam/specializations/sfm.h @@ -18,16 +18,11 @@ py::bind_vector > >( py::bind_vector > >( m_, "BinaryMeasurementsRot3"); py::bind_map(m_, "KeyPairDoubleMap"); +py::bind_vector>(m_, "SfmTrack2dVector"); +py::bind_vector>(m_, "SfmTracks"); +py::bind_vector>(m_, "SfmCameras"); +py::bind_vector>>( + m_, "SfmMeasurementVector"); -py::bind_vector< - std::vector >( - m_, "SfmTracks"); - -py::bind_vector< - std::vector >( - m_, "SfmCameras"); - -py::bind_vector< - std::vector>>( - m_, "SfmMeasurementVector" - ); +py::bind_map(m_, "MatchIndicesMap"); +py::bind_vector>(m_, "KeypointsVector"); diff --git a/python/gtsam/tests/test_dsf_map.py b/python/gtsam/tests/test_DSFMap.py similarity index 88% rename from python/gtsam/tests/test_dsf_map.py rename to python/gtsam/tests/test_DSFMap.py index 6cae98ff5..f973f7c99 100644 --- a/python/gtsam/tests/test_dsf_map.py +++ b/python/gtsam/tests/test_DSFMap.py @@ -15,8 +15,7 @@ from __future__ import print_function import unittest from typing import Tuple -import gtsam -from gtsam import IndexPair +from gtsam import DSFMapIndexPair, IndexPair, IndexPairSetAsArray from gtsam.utils.test_case import GtsamTestCase @@ -29,10 +28,10 @@ class TestDSFMap(GtsamTestCase): def key(index_pair) -> Tuple[int, int]: return index_pair.i(), index_pair.j() - dsf = gtsam.DSFMapIndexPair() - pair1 = gtsam.IndexPair(1, 18) + dsf = DSFMapIndexPair() + pair1 = IndexPair(1, 18) self.assertEqual(key(dsf.find(pair1)), key(pair1)) - pair2 = gtsam.IndexPair(2, 2) + pair2 = IndexPair(2, 2) # testing the merge feature of dsf dsf.merge(pair1, pair2) @@ -45,7 +44,7 @@ class TestDSFMap(GtsamTestCase): k'th detected keypoint in image i. For the data below, merging such measurements into feature tracks across frames should create 2 distinct sets. """ - dsf = gtsam.DSFMapIndexPair() + dsf = DSFMapIndexPair() dsf.merge(IndexPair(0, 1), IndexPair(1, 2)) dsf.merge(IndexPair(0, 1), IndexPair(3, 4)) dsf.merge(IndexPair(4, 5), IndexPair(6, 8)) @@ -56,7 +55,7 @@ class TestDSFMap(GtsamTestCase): for i in sets: set_keys = [] s = sets[i] - for val in gtsam.IndexPairSetAsArray(s): + for val in IndexPairSetAsArray(s): set_keys.append((val.i(), val.j())) merged_sets.add(tuple(set_keys)) diff --git a/python/gtsam/tests/test_DsfTrackGenerator.py b/python/gtsam/tests/test_DsfTrackGenerator.py new file mode 100644 index 000000000..e600227c9 --- /dev/null +++ b/python/gtsam/tests/test_DsfTrackGenerator.py @@ -0,0 +1,96 @@ +"""Unit tests for track generation using a Disjoint Set Forest data structure. + +Authors: John Lambert +""" + +import unittest + +import gtsam +import numpy as np +from gtsam import IndexPair, KeypointsVector, MatchIndicesMap, Point2, SfmMeasurementVector, SfmTrack2d +from gtsam.gtsfm import Keypoints +from gtsam.utils.test_case import GtsamTestCase + + +class TestDsfTrackGenerator(GtsamTestCase): + """Tests for DsfTrackGenerator.""" + + def test_track_generation(self) -> None: + """Ensures that DSF generates three tracks from measurements + in 3 images (H=200,W=400).""" + kps_i0 = Keypoints(np.array([[10.0, 20], [30, 40]])) + kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]])) + kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]])) + + keypoints_list = KeypointsVector() + keypoints_list.append(kps_i0) + keypoints_list.append(kps_i1) + keypoints_list.append(kps_i2) + + # For each image pair (i1,i2), we provide a (K,2) matrix + # of corresponding image indices (k1,k2). + matches_dict = MatchIndicesMap() + matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]]) + matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]]) + + tracks = gtsam.gtsfm.tracksFromPairwiseMatches( + matches_dict, + keypoints_list, + verbose=False, + ) + assert len(tracks) == 3 + + # Verify track 0. + track0 = tracks[0] + assert track0.numberMeasurements() == 2 + np.testing.assert_allclose(track0.measurements[0][1], Point2(10, 20)) + np.testing.assert_allclose(track0.measurements[1][1], Point2(50, 60)) + assert track0.measurements[0][0] == 0 + assert track0.measurements[1][0] == 1 + np.testing.assert_allclose( + track0.measurementMatrix(), + [ + [10, 20], + [50, 60], + ], + ) + np.testing.assert_allclose(track0.indexVector(), [0, 1]) + + # Verify track 1. + track1 = tracks[1] + np.testing.assert_allclose( + track1.measurementMatrix(), + [ + [30, 40], + [70, 80], + [130, 140], + ], + ) + np.testing.assert_allclose(track1.indexVector(), [0, 1, 2]) + + # Verify track 2. + track2 = tracks[2] + np.testing.assert_allclose( + track2.measurementMatrix(), + [ + [90, 100], + [110, 120], + ], + ) + np.testing.assert_allclose(track2.indexVector(), [1, 2]) + + +class TestSfmTrack2d(GtsamTestCase): + """Tests for SfmTrack2d.""" + + def test_sfm_track_2d_constructor(self) -> None: + """ """ + measurements = SfmMeasurementVector() + measurements.append((0, Point2(10, 20))) + track = SfmTrack2d(measurements=measurements) + track.measurement(0) + track.numberMeasurements() == 1 + + +if __name__ == "__main__": + unittest.main()