Merge branch 'develop' into fix/doxygen

release/4.3a0
Varun Agrawal 2022-10-25 17:44:34 -04:00
commit 33e24265bd
60 changed files with 2593 additions and 978 deletions

View File

@ -101,8 +101,6 @@ if(GTSAM_BUILD_PYTHON OR GTSAM_INSTALL_MATLAB_TOOLBOX)
# Copy matlab.h to the correct folder. # Copy matlab.h to the correct folder.
configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h
${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY) ${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY)
# Add the include directories so that matlab.h can be found
include_directories("${PROJECT_BINARY_DIR}" "${GTSAM_EIGEN_INCLUDE_FOR_BUILD}")
add_subdirectory(wrap) add_subdirectory(wrap)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake")

View File

@ -21,6 +21,10 @@ else()
find_dependency(Boost @BOOST_FIND_MINIMUM_VERSION@ COMPONENTS @BOOST_FIND_MINIMUM_COMPONENTS@) find_dependency(Boost @BOOST_FIND_MINIMUM_VERSION@ COMPONENTS @BOOST_FIND_MINIMUM_COMPONENTS@)
endif() endif()
if(@GTSAM_USE_SYSTEM_EIGEN@)
find_dependency(Eigen3 REQUIRED)
endif()
# Load exports # Load exports
include(${OUR_CMAKE_DIR}/@PACKAGE_NAME@-exports.cmake) include(${OUR_CMAKE_DIR}/@PACKAGE_NAME@-exports.cmake)

View File

@ -1,81 +0,0 @@
# - Try to find Eigen3 lib
#
# This module supports requiring a minimum version, e.g. you can do
# find_package(Eigen3 3.1.2)
# to require version 3.1.2 or newer of Eigen3.
#
# Once done this will define
#
# EIGEN3_FOUND - system has eigen lib with correct version
# EIGEN3_INCLUDE_DIR - the eigen include directory
# EIGEN3_VERSION - eigen version
# Copyright (c) 2006, 2007 Montel Laurent, <montel@kde.org>
# Copyright (c) 2008, 2009 Gael Guennebaud, <g.gael@free.fr>
# Copyright (c) 2009 Benoit Jacob <jacob.benoit.1@gmail.com>
# Redistribution and use is allowed according to the terms of the 2-clause BSD license.
if(NOT Eigen3_FIND_VERSION)
if(NOT Eigen3_FIND_VERSION_MAJOR)
set(Eigen3_FIND_VERSION_MAJOR 2)
endif(NOT Eigen3_FIND_VERSION_MAJOR)
if(NOT Eigen3_FIND_VERSION_MINOR)
set(Eigen3_FIND_VERSION_MINOR 91)
endif(NOT Eigen3_FIND_VERSION_MINOR)
if(NOT Eigen3_FIND_VERSION_PATCH)
set(Eigen3_FIND_VERSION_PATCH 0)
endif(NOT Eigen3_FIND_VERSION_PATCH)
set(Eigen3_FIND_VERSION "${Eigen3_FIND_VERSION_MAJOR}.${Eigen3_FIND_VERSION_MINOR}.${Eigen3_FIND_VERSION_PATCH}")
endif(NOT Eigen3_FIND_VERSION)
macro(_eigen3_check_version)
file(READ "${EIGEN3_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h" _eigen3_version_header)
string(REGEX MATCH "define[ \t]+EIGEN_WORLD_VERSION[ \t]+([0-9]+)" _eigen3_world_version_match "${_eigen3_version_header}")
set(EIGEN3_WORLD_VERSION "${CMAKE_MATCH_1}")
string(REGEX MATCH "define[ \t]+EIGEN_MAJOR_VERSION[ \t]+([0-9]+)" _eigen3_major_version_match "${_eigen3_version_header}")
set(EIGEN3_MAJOR_VERSION "${CMAKE_MATCH_1}")
string(REGEX MATCH "define[ \t]+EIGEN_MINOR_VERSION[ \t]+([0-9]+)" _eigen3_minor_version_match "${_eigen3_version_header}")
set(EIGEN3_MINOR_VERSION "${CMAKE_MATCH_1}")
set(EIGEN3_VERSION ${EIGEN3_WORLD_VERSION}.${EIGEN3_MAJOR_VERSION}.${EIGEN3_MINOR_VERSION})
if(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
set(EIGEN3_VERSION_OK FALSE)
else(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
set(EIGEN3_VERSION_OK TRUE)
endif(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
if(NOT EIGEN3_VERSION_OK)
message(STATUS "Eigen3 version ${EIGEN3_VERSION} found in ${EIGEN3_INCLUDE_DIR}, "
"but at least version ${Eigen3_FIND_VERSION} is required")
endif(NOT EIGEN3_VERSION_OK)
endmacro(_eigen3_check_version)
if (EIGEN3_INCLUDE_DIR)
# in cache already
_eigen3_check_version()
set(EIGEN3_FOUND ${EIGEN3_VERSION_OK})
else (EIGEN3_INCLUDE_DIR)
find_path(EIGEN3_INCLUDE_DIR NAMES signature_of_eigen3_matrix_library
PATHS
${CMAKE_INSTALL_PREFIX}/include
${KDE4_INCLUDE_DIR}
PATH_SUFFIXES eigen3 eigen
)
if(EIGEN3_INCLUDE_DIR)
_eigen3_check_version()
endif(EIGEN3_INCLUDE_DIR)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Eigen3 DEFAULT_MSG EIGEN3_INCLUDE_DIR EIGEN3_VERSION_OK)
mark_as_advanced(EIGEN3_INCLUDE_DIR)
endif(EIGEN3_INCLUDE_DIR)

View File

@ -1,7 +1,7 @@
############################################################################### ###############################################################################
# Option for using system Eigen or GTSAM-bundled Eigen # Option for using system Eigen or GTSAM-bundled Eigen
# Default: Use system's Eigen if found automatically: # Default: Use system's Eigen if found automatically:
find_package(Eigen3 QUIET) find_package(Eigen3 CONFIG QUIET)
set(USE_SYSTEM_EIGEN_INITIAL_VALUE ${Eigen3_FOUND}) set(USE_SYSTEM_EIGEN_INITIAL_VALUE ${Eigen3_FOUND})
option(GTSAM_USE_SYSTEM_EIGEN "Find and use system-installed Eigen. If 'off', use the one bundled with GTSAM" ${USE_SYSTEM_EIGEN_INITIAL_VALUE}) option(GTSAM_USE_SYSTEM_EIGEN "Find and use system-installed Eigen. If 'off', use the one bundled with GTSAM" ${USE_SYSTEM_EIGEN_INITIAL_VALUE})
unset(USE_SYSTEM_EIGEN_INITIAL_VALUE) unset(USE_SYSTEM_EIGEN_INITIAL_VALUE)
@ -14,10 +14,14 @@ endif()
# Switch for using system Eigen or GTSAM-bundled Eigen # Switch for using system Eigen or GTSAM-bundled Eigen
if(GTSAM_USE_SYSTEM_EIGEN) if(GTSAM_USE_SYSTEM_EIGEN)
find_package(Eigen3 REQUIRED) # need to find again as REQUIRED # Since Eigen 3.3.0 a Eigen3Config.cmake is available so use it.
find_package(Eigen3 CONFIG REQUIRED) # need to find again as REQUIRED
# Use generic Eigen include paths e.g. <Eigen/Core> # The actual include directory (for BUILD cmake target interface):
set(GTSAM_EIGEN_INCLUDE_FOR_INSTALL "${EIGEN3_INCLUDE_DIR}") # Note: EIGEN3_INCLUDE_DIR points to some random location on some eigen
# versions. So here I use the target itself to get the proper include
# directory (it is generated by cmake, thus has the correct path)
get_target_property(GTSAM_EIGEN_INCLUDE_FOR_BUILD Eigen3::Eigen INTERFACE_INCLUDE_DIRECTORIES)
# check if MKL is also enabled - can have one or the other, but not both! # check if MKL is also enabled - can have one or the other, but not both!
# Note: Eigen >= v3.2.5 includes our patches # Note: Eigen >= v3.2.5 includes our patches
@ -30,9 +34,6 @@ if(GTSAM_USE_SYSTEM_EIGEN)
if(EIGEN_USE_MKL_ALL AND (EIGEN3_VERSION VERSION_EQUAL 3.3.4)) if(EIGEN_USE_MKL_ALL AND (EIGEN3_VERSION VERSION_EQUAL 3.3.4))
message(FATAL_ERROR "MKL does not work with Eigen 3.3.4 because of a bug in Eigen. See http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1527. Disable GTSAM_USE_SYSTEM_EIGEN to use GTSAM's copy of Eigen, disable GTSAM_WITH_EIGEN_MKL, or upgrade/patch your installation of Eigen.") message(FATAL_ERROR "MKL does not work with Eigen 3.3.4 because of a bug in Eigen. See http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1527. Disable GTSAM_USE_SYSTEM_EIGEN to use GTSAM's copy of Eigen, disable GTSAM_WITH_EIGEN_MKL, or upgrade/patch your installation of Eigen.")
endif() endif()
# The actual include directory (for BUILD cmake target interface):
set(GTSAM_EIGEN_INCLUDE_FOR_BUILD "${EIGEN3_INCLUDE_DIR}")
else() else()
# Use bundled Eigen include path. # Use bundled Eigen include path.
# Clear any variables set by FindEigen3 # Clear any variables set by FindEigen3
@ -46,6 +47,19 @@ else()
# The actual include directory (for BUILD cmake target interface): # The actual include directory (for BUILD cmake target interface):
set(GTSAM_EIGEN_INCLUDE_FOR_BUILD "${GTSAM_SOURCE_DIR}/gtsam/3rdparty/Eigen/") set(GTSAM_EIGEN_INCLUDE_FOR_BUILD "${GTSAM_SOURCE_DIR}/gtsam/3rdparty/Eigen/")
add_library(gtsam_eigen3 INTERFACE)
target_include_directories(gtsam_eigen3 INTERFACE
$<BUILD_INTERFACE:${GTSAM_EIGEN_INCLUDE_FOR_BUILD}>
$<INSTALL_INTERFACE:${GTSAM_EIGEN_INCLUDE_FOR_INSTALL}>
)
add_library(Eigen3::Eigen ALIAS gtsam_eigen3)
install(TARGETS gtsam_eigen3 EXPORT GTSAM-exports PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
list(APPEND GTSAM_EXPORTED_TARGETS gtsam_eigen3)
set(GTSAM_EXPORTED_TARGETS "${GTSAM_EXPORTED_TARGETS}" PARENT_SCOPE)
endif() endif()
# Detect Eigen version: # Detect Eigen version:

View File

@ -117,12 +117,9 @@ set_target_properties(gtsam PROPERTIES
VERSION ${gtsam_version} VERSION ${gtsam_version}
SOVERSION ${gtsam_soversion}) SOVERSION ${gtsam_soversion})
# Append Eigen include path, set in top-level CMakeLists.txt to either
# system-eigen, or GTSAM eigen path # system-eigen, or GTSAM eigen path
target_include_directories(gtsam PUBLIC target_link_libraries(gtsam PUBLIC Eigen3::Eigen)
$<BUILD_INTERFACE:${GTSAM_EIGEN_INCLUDE_FOR_BUILD}>
$<INSTALL_INTERFACE:${GTSAM_EIGEN_INCLUDE_FOR_INSTALL}>
)
# MKL include dir: # MKL include dir:
if (GTSAM_USE_EIGEN_MKL) if (GTSAM_USE_EIGEN_MKL)
target_include_directories(gtsam PUBLIC ${MKL_INCLUDE_DIR}) target_include_directories(gtsam PUBLIC ${MKL_INCLUDE_DIR})

View File

@ -221,6 +221,6 @@ void PrintForest(const FOREST& forest, std::string str,
PrintForestVisitorPre visitor(keyFormatter); PrintForestVisitorPre visitor(keyFormatter);
DepthFirstForest(forest, str, visitor); DepthFirstForest(forest, str, visitor);
} }
} } // namespace treeTraversal
} } // namespace gtsam

View File

@ -11,15 +11,17 @@
/** /**
* @file Assignment.h * @file Assignment.h
* @brief An assignment from labels to a discrete value index (size_t) * @brief An assignment from labels to a discrete value index (size_t)
* @author Frank Dellaert * @author Frank Dellaert
* @date Feb 5, 2012 * @date Feb 5, 2012
*/ */
#pragma once #pragma once
#include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -33,13 +35,30 @@ namespace gtsam {
*/ */
template <class L> template <class L>
class Assignment : public std::map<L, size_t> { class Assignment : public std::map<L, size_t> {
/**
* @brief Default method used by `labelFormatter` or `valueFormatter` when
* printing.
*
* @param x The value passed to format.
* @return std::string
*/
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}
public: public:
using std::map<L, size_t>::operator=; using std::map<L, size_t>::operator=;
void print(const std::string& s = "Assignment: ") const { void print(const std::string& s = "Assignment: ",
const std::function<std::string(L)>& labelFormatter =
&DefaultFormatter) const {
std::cout << s << ": "; std::cout << s << ": ";
for (const typename Assignment::value_type& keyValue : *this) for (const typename Assignment::value_type& keyValue : *this) {
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; std::cout << "(" << labelFormatter(keyValue.first) << ", "
<< keyValue.second << ")";
}
std::cout << std::endl; std::cout << std::endl;
} }

View File

@ -48,4 +48,25 @@ namespace gtsam {
return keys & key2; return keys & key2;
} }
void DiscreteKeys::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
for (auto&& dkey : *this) {
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
<< std::endl;
}
}
bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const {
if (this->size() != other.size()) {
return false;
}
for (size_t i = 0; i < this->size(); i++) {
if (this->at(i).first != other.at(i).first ||
this->at(i).second != other.at(i).second) {
return false;
}
}
return true;
}
} }

View File

@ -21,6 +21,7 @@
#include <gtsam/global_includes.h> #include <gtsam/global_includes.h>
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
#include <boost/serialization/vector.hpp>
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
@ -70,8 +71,30 @@ namespace gtsam {
push_back(key); push_back(key);
return *this; return *this;
} }
/// Print the keys and cardinalities.
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Check equality to another DiscreteKeys object.
bool equals(const DiscreteKeys& other, double tol = 0) const;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& boost::serialization::make_nvp(
"DiscreteKeys",
boost::serialization::base_object<std::vector<DiscreteKey>>(*this));
}
}; // DiscreteKeys }; // DiscreteKeys
/// Create a list from two keys /// Create a list from two keys
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
}
// traits
template <>
struct traits<DiscreteKeys> : public Testable<DiscreteKeys> {};
} // namespace gtsam

View File

@ -159,6 +159,10 @@ TEST(DiscreteBayesTree, ThinTree) {
clique->separatorMarginal(EliminateDiscrete); clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
// check separator marginal P(S9), should be P(14) // check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9]; clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 = DiscreteFactorGraph separatorMarginal9 =

View File

@ -16,14 +16,29 @@
* @author Duy-Nguyen Ta * @author Duy-Nguyen Ta
*/ */
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
using namespace boost::assign; using namespace boost::assign;
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace gtsam::serializationTestHelpers;
/* ************************************************************************* */
TEST(DisreteKeys, Serialization) {
DiscreteKeys keys;
keys& DiscreteKey(0, 2);
keys& DiscreteKey(1, 3);
keys& DiscreteKey(2, 4);
EXPECT(equalsObj<DiscreteKeys>(keys));
EXPECT(equalsXML<DiscreteKeys>(keys));
EXPECT(equalsBinary<DiscreteKeys>(keys));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
@ -31,4 +46,3 @@ int main() {
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -33,8 +33,6 @@ static const Point2 P(0.2, 0.7);
static const Rot2 R = Rot2::fromAngle(0.3); static const Rot2 R = Rot2::fromAngle(0.3);
static const double s = 4; static const double s = 4;
const double degree = M_PI / 180;
//****************************************************************************** //******************************************************************************
TEST(Similarity2, Concepts) { TEST(Similarity2, Concepts) {
BOOST_CONCEPT_ASSERT((IsGroup<Similarity2>)); BOOST_CONCEPT_ASSERT((IsGroup<Similarity2>));

View File

@ -66,6 +66,27 @@ class KeySet {
void serialize() const; void serialize() const;
}; };
// Actually a vector<Key>, needed for Matlab
class KeyVector {
KeyVector();
KeyVector(const gtsam::KeyVector& other);
// Note: no print function
// common STL methods
size_t size() const;
bool empty() const;
void clear();
// structure specific methods
size_t at(size_t i) const;
size_t front() const;
size_t back() const;
void push_back(size_t key) const;
void serialize() const;
};
// Actually a FastMap<Key,int> // Actually a FastMap<Key,int>
class KeyGroupMap { class KeyGroupMap {
KeyGroupMap(); KeyGroupMap();

View File

@ -119,33 +119,90 @@ void GaussianMixture::print(const std::string &s,
"", [&](Key k) { return formatter(k); }, "", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string { [&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd; RedirectCout rd;
if (gf && !gf->empty()) if (gf && !gf->empty()) {
gf->print("", formatter); gf->print("", formatter);
else return rd.str();
return {"nullptr"}; } else {
return rd.str(); return "nullptr";
}
}); });
} }
/* *******************************************************************************/ /* ************************************************************************* */
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { /// Return the DiscreteKey vector as a set.
// Functional which loops over all assignments and create a set of std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
// GaussianConditionals std::set<DiscreteKey> s;
auto pruner = [&decisionTree]( s.insert(dkeys.begin(), dkeys.end());
return s;
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
const Assignment<Key> &choices, const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional) const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr { -> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); DiscreteValues values(choices);
if (decisionTree(values) == 0.0) { // Case where the gaussian mixture has the same
// empty aka null pointer // discrete keys as the decision tree.
boost::shared_ptr<GaussianConditional> null; if (gaussianMixtureKeySet == decisionTreeKeySet) {
return null; if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
} else { } else {
return conditional; std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
gaussianMixtureKeySet.begin(),
gaussianMixtureKeySet.end(),
std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end());
// If any one of the sub-branches are non-zero,
// we need this conditional.
if (decisionTree(augmented_values) > 0.0) {
return conditional;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return nullptr;
} }
}; };
return pruner;
}
/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = prunerFunc(decisionTree);
auto pruned_conditionals = conditionals_.apply(pruner); auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_; conditionals_.root_ = pruned_conditionals.root_;

View File

@ -70,6 +70,17 @@ class GTSAM_EXPORT GaussianMixture
*/ */
Sum asGaussianFactorGraphTree() const; Sum asGaussianFactorGraphTree() const;
/**
* @brief Helper function to get the pruner functor.
*
* @param decisionTree The pruned discrete probability decision tree.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &decisionTree);
public: public:
/// @name Constructors /// @name Constructors
/// @{ /// @{

View File

@ -57,11 +57,12 @@ void GaussianMixtureFactor::print(const std::string &s,
[&](const GaussianFactor::shared_ptr &gf) -> std::string { [&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd; RedirectCout rd;
std::cout << ":\n"; std::cout << ":\n";
if (gf) if (gf && !gf->empty()) {
gf->print("", formatter); gf->print("", formatter);
else return rd.str();
return {"nullptr"}; } else {
return rd.str(); return "nullptr";
}
}); });
std::cout << "}" << std::endl; std::cout << "}" << std::endl;
} }

View File

@ -15,23 +15,40 @@
* @date January 2022 * @date January 2022
*/ */
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
/// Return the DiscreteKey vector as a set. DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) { AlgebraicDecisionTree<Key> decisionTree;
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end()); // The canonical decision tree factor which will get the discrete conditionals
return s; // added to it.
DecisionTreeFactor dtFactor;
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional());
dtFactor = dtFactor * f;
}
}
return boost::make_shared<DecisionTreeFactor>(dtFactor);
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune( HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
const DecisionTreeFactor::shared_ptr &discreteFactor) const { // Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr discreteFactor =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
/* To Prune, we visitWith every leaf in the GaussianMixture. /* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree * For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr. * for 0.0 probability, then just set the leaf to a nullptr.
@ -41,61 +58,18 @@ HybridBayesNet HybridBayesNet::prune(
HybridBayesNet prunedBayesNetFragment; HybridBayesNet prunedBayesNetFragment;
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
if ((*discreteFactor)(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};
// Go through all the conditionals in the // Go through all the conditionals in the
// Bayes Net and prune them as per discreteFactor. // Bayes Net and prune them as per discreteFactor.
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i); HybridConditional::shared_ptr conditional = this->at(i);
GaussianMixture::shared_ptr gaussianMixture = if (conditional->isHybrid()) {
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner()); GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
if (gaussianMixture) { // Make a copy of the gaussian mixture and prune it!
// We may have mixtures with less discrete keys than discreteFactor so we auto prunedGaussianMixture =
// skip those since the label assignment does not exist. boost::make_shared<GaussianMixture>(*gaussianMixture);
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys()); prunedGaussianMixture->prune(*discreteFactor);
auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
if (gmKeySet != dfKeySet) {
// Add the gaussianMixture which doesn't have to be pruned.
prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(gaussianMixture));
continue;
}
// Run the pruning to get a new, pruned tree
GaussianMixture::Conditionals prunedTree =
gaussianMixture->conditionals().apply(pruner);
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
// reverse keys to get a natural ordering
std::reverse(discreteKeys.begin(), discreteKeys.end());
// Convert from boost::iterator_range to KeyVector
// so we can pass it to constructor.
KeyVector frontals(gaussianMixture->frontals().begin(),
gaussianMixture->frontals().end()),
parents(gaussianMixture->parents().begin(),
gaussianMixture->parents().end());
// Create the new gaussian mixture and add it to the bayes net.
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
frontals, parents, discreteKeys, prunedTree);
// Type-erase and add to the pruned Bayes Net fragment. // Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back( prunedBayesNetFragment.push_back(
@ -111,14 +85,18 @@ HybridBayesNet HybridBayesNet::prune(
} }
/* ************************************************************************* */ /* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const { GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner()); return factors_.at(i)->asMixture();
}
/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian();
} }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return boost::dynamic_pointer_cast<DiscreteConditional>( return factors_.at(i)->asDiscreteConditional();
factors_.at(i)->inner());
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -126,16 +104,45 @@ GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const { const DiscreteValues &assignment) const {
GaussianBayesNet gbn; GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) { for (size_t idx = 0; idx < size(); idx++) {
GaussianMixture gm = *this->atGaussian(idx); if (factors_.at(idx)->isHybrid()) {
gbn.push_back(gm(assignment)); // If factor is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx);
gbn.push_back(gm(assignment));
} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, add gaussian conditional.
gbn.push_back((this->atGaussian(idx)));
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue.
continue;
}
} }
return gbn; return gbn;
} }
/* *******************************************************************************/ /* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const { HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this); // Solve for the MPE
return dag.argmax(); DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
}
}
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe);
return HybridValues(mpe, gbn.optimize());
}
/* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);
return gbn.optimize();
} }
} // namespace gtsam } // namespace gtsam

View File

@ -18,6 +18,7 @@
#pragma once #pragma once
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/global_includes.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
@ -39,12 +40,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
using shared_ptr = boost::shared_ptr<HybridBayesNet>; using shared_ptr = boost::shared_ptr<HybridBayesNet>;
using sharedConditional = boost::shared_ptr<ConditionalType>; using sharedConditional = boost::shared_ptr<ConditionalType>;
/// @name Standard Constructors
/// @{
/** Construct empty bayes net */ /** Construct empty bayes net */
HybridBayesNet() = default; HybridBayesNet() = default;
/// Prune the Hybrid Bayes Net given the discrete decision tree. /// @}
HybridBayesNet prune( /// @name Testable
const DecisionTreeFactor::shared_ptr &discreteFactor) const; /// @{
/** Check equality */
bool equals(const This &bn, double tol = 1e-9) const {
return Base::equals(bn, tol);
}
/// print graph
void print(
const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
/// @}
/// @name Standard Interface
/// @{
/// Add HybridConditional to Bayes Net /// Add HybridConditional to Bayes Net
using Base::add; using Base::add;
@ -55,8 +75,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
HybridConditional(boost::make_shared<DiscreteConditional>(key, table))); HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
} }
using Base::push_back;
/// Get a specific Gaussian mixture by index `i`. /// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atGaussian(size_t i) const; GaussianMixture::shared_ptr atMixture(size_t i) const;
/// Get a specific Gaussian conditional by index `i`.
GaussianConditional::shared_ptr atGaussian(size_t i) const;
/// Get a specific discrete conditional by index `i`. /// Get a specific discrete conditional by index `i`.
DiscreteConditional::shared_ptr atDiscrete(size_t i) const; DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
@ -70,10 +95,49 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/ */
GaussianBayesNet choose(const DiscreteValues &assignment) const; GaussianBayesNet choose(const DiscreteValues &assignment) const;
/// Solve the HybridBayesNet by back-substitution. /**
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and * @brief Solve the HybridBayesNet by first computing the MPE of all the
/// put this method there? * discrete variables and then optimizing the continuous variables based on
* the MPE assignment.
*
* @return HybridValues
*/
HybridValues optimize() const; HybridValues optimize() const;
/**
* @brief Given the discrete assignment, return the optimized estimate for the
* selected Gaussian BayesNet.
*
* @param assignment An assignment of discrete values.
* @return Values
*/
VectorValues optimize(const DiscreteValues &assignment) const;
protected:
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;
public:
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
}; };
/// traits
template <>
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
} // namespace gtsam } // namespace gtsam

View File

@ -18,10 +18,13 @@
*/ */
#include <gtsam/base/treeTraversal-inst.h> #include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/BayesTree-inst.h> #include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h> #include <gtsam/inference/BayesTreeCliqueBase-inst.h>
#include <gtsam/linear/GaussianJunctionTree.h>
namespace gtsam { namespace gtsam {
@ -35,4 +38,161 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol); return Base::equals(other, tol);
} }
/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
DiscreteBayesNet dbn;
DiscreteValues mpe;
auto root = roots_.at(0);
// Access the clique and get the underlying hybrid conditional
HybridConditional::shared_ptr root_conditional = root->conditional();
// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
dbn.push_back(root_conditional->asDiscreteConditional());
mpe = DiscreteFactorGraph(dbn).optimize();
} else {
throw std::runtime_error(
"HybridBayesTree root is not discrete-only. Please check elimination "
"ordering or use continuous factor graph.");
}
VectorValues values = optimize(mpe);
return HybridValues(mpe, values);
}
/* ************************************************************************* */
/**
* @brief Helper class for Depth First Forest traversal on the HybridBayesTree.
*
* When traversing the tree, the pre-order visitor will receive an instance of
* this class with the parent clique data.
*/
struct HybridAssignmentData {
const DiscreteValues assignment_;
GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_;
/**
* @brief Construct a new Hybrid Assignment Data object.
*
* @param assignment The MPE assignment for the optimal Gaussian cliques.
* @param parentClique The clique from the parent node of the current node.
* @param gbt The Gaussian Bayes Tree being generated during tree traversal.
*/
HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt)
: assignment_(assignment),
parentClique_(parentClique),
gaussianbayesTree_(gbt) {}
/**
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The HybridAssignmentData from the parent node.
* @return HybridAssignmentData which is passed to the children.
*/
static HybridAssignmentData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& node,
HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
} else if (hybrid_conditional->isContinuous()) {
conditional = hybrid_conditional->asGaussian();
} else {
// Discrete only conditional, so we set to empty gaussian conditional
conditional = boost::make_shared<GaussianConditional>();
}
// Create the GaussianClique for the current node
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
// Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_);
return data;
}
};
/* *************************************************************************
*/
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesTree gbt;
HybridAssignmentData rootData(assignment, 0, &gbt);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
*this, rootData, HybridAssignmentData::AssignmentPreOrderVisitor,
visitorPost);
}
VectorValues result = gbt.optimize();
// Return the optimized bayes net result.
return result;
}
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner());
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDecisionTree;
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
const HybridBayesTree::sharedNode& parentClique)
: prunedDecisionTree(prunedDecisionTree) {}
/**
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The data from the parent node.
* @return HybridPrunerData which is passed to the children.
*/
static HybridPrunerData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& clique,
HybridPrunerData& parentData) {
// Get the conditional
HybridConditional::shared_ptr conditional = clique->conditional();
// If conditional is hybrid, we prune it.
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();
gaussianMixture->prune(parentData.prunedDecisionTree);
}
return parentData;
}
};
HybridPrunerData rootData(prunedDecisionTree, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
visitorPost);
}
}
} // namespace gtsam } // namespace gtsam

View File

@ -73,9 +73,46 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */ /** Check equality */
bool equals(const This& other, double tol = 1e-9) const; bool equals(const This& other, double tol = 1e-9) const;
/**
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
* set of discrete variables and using it to compute the best continuous
* update delta.
*
* @return HybridValues
*/
HybridValues optimize() const;
/**
* @brief Recursively optimize the BayesTree to produce a vector solution.
*
* @param assignment The discrete values assignment to select the Gaussian
* mixtures.
* @return VectorValues
*/
VectorValues optimize(const DiscreteValues& assignment) const;
/**
* @brief Prune the underlying Bayes tree.
*
* @param maxNumberLeaves The max number of leaf nodes to keep.
*/
void prune(const size_t maxNumberLeaves);
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
}; };
/// traits
template <>
struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};
/** /**
* @brief Class for Hybrid Bayes tree orphan subtrees. * @brief Class for Hybrid Bayes tree orphan subtrees.
* *

View File

@ -34,8 +34,6 @@
namespace gtsam { namespace gtsam {
class HybridGaussianFactorGraph;
/** /**
* Hybrid Conditional Density * Hybrid Conditional Density
* *
@ -71,7 +69,7 @@ class GTSAM_EXPORT HybridConditional
BaseConditional; ///< Typedef to our conditional base class BaseConditional; ///< Typedef to our conditional base class
protected: protected:
// Type-erased pointer to the inner type /// Type-erased pointer to the inner type
boost::shared_ptr<Factor> inner_; boost::shared_ptr<Factor> inner_;
public: public:
@ -129,8 +127,7 @@ class GTSAM_EXPORT HybridConditional
* @param gaussianMixture Gaussian Mixture Conditional used to create the * @param gaussianMixture Gaussian Mixture Conditional used to create the
* HybridConditional. * HybridConditional.
*/ */
HybridConditional( HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
boost::shared_ptr<GaussianMixture> gaussianMixture);
/** /**
* @brief Return HybridConditional as a GaussianMixture * @brief Return HybridConditional as a GaussianMixture
@ -142,6 +139,17 @@ class GTSAM_EXPORT HybridConditional
return boost::static_pointer_cast<GaussianMixture>(inner_); return boost::static_pointer_cast<GaussianMixture>(inner_);
} }
/**
* @brief Return HybridConditional as a GaussianConditional
*
* @return GaussianConditional::shared_ptr
*/
GaussianConditional::shared_ptr asGaussian() {
if (!isContinuous())
throw std::invalid_argument("Not a continuous conditional");
return boost::static_pointer_cast<GaussianConditional>(inner_);
}
/** /**
* @brief Return conditional as a DiscreteConditional * @brief Return conditional as a DiscreteConditional
* *
@ -170,10 +178,19 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type /// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; } boost::shared_ptr<Factor> inner() { return inner_; }
}; // DiscreteConditional private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}
}; // HybridConditional
// traits // traits
template <> template <>
struct traits<HybridConditional> : public Testable<DiscreteConditional> {}; struct traits<HybridConditional> : public Testable<HybridConditional> {};
} // namespace gtsam } // namespace gtsam

View File

@ -50,10 +50,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &keys) HybridFactor::HybridFactor(const KeyVector &keys)
: Base(keys), : Base(keys), isContinuous_(true), continuousKeys_(keys) {}
isContinuous_(true),
nrContinuous_(keys.size()),
continuousKeys_(keys) {}
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &continuousKeys, HybridFactor::HybridFactor(const KeyVector &continuousKeys,
@ -62,7 +59,6 @@ HybridFactor::HybridFactor(const KeyVector &continuousKeys,
isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)),
isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
nrContinuous_(continuousKeys.size()),
discreteKeys_(discreteKeys), discreteKeys_(discreteKeys),
continuousKeys_(continuousKeys) {} continuousKeys_(continuousKeys) {}
@ -103,7 +99,6 @@ void HybridFactor::print(const std::string &s,
if (d < discreteKeys_.size() - 1) { if (d < discreteKeys_.size() - 1) {
std::cout << " "; std::cout << " ";
} }
} }
std::cout << "]"; std::cout << "]";
} }

View File

@ -49,8 +49,6 @@ class GTSAM_EXPORT HybridFactor : public Factor {
bool isContinuous_ = false; bool isContinuous_ = false;
bool isHybrid_ = false; bool isHybrid_ = false;
size_t nrContinuous_ = 0;
protected: protected:
// Set of DiscreteKeys for this factor. // Set of DiscreteKeys for this factor.
DiscreteKeys discreteKeys_; DiscreteKeys discreteKeys_;
@ -131,6 +129,19 @@ class GTSAM_EXPORT HybridFactor : public Factor {
const KeyVector &continuousKeys() const { return continuousKeys_; } const KeyVector &continuousKeys() const { return continuousKeys_; }
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
}
}; };
// HybridFactor // HybridFactor

View File

@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
push_hybrid(p); push_hybrid(p);
} }
} }
/// Get all the discrete keys in the factor graph.
const KeySet discreteKeys() const {
KeySet discrete_keys;
for (auto& factor : factors_) {
for (const DiscreteKey& k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
return discrete_keys;
}
/// Get all the continuous keys in the factor graph.
const KeySet continuousKeys() const {
KeySet keys;
for (auto& factor : factors_) {
for (const Key& key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -96,8 +96,12 @@ GaussianMixtureFactor::Sum sumFrontals(
} }
} else if (f->isContinuous()) { } else if (f->isContinuous()) {
deferredFactors.push_back( if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner()); deferredFactors.push_back(gf->inner());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
deferredFactors.push_back(cg->asGaussian());
}
} else if (f->isDiscrete()) { } else if (f->isDiscrete()) {
// Don't do anything for discrete-only factors // Don't do anything for discrete-only factors
@ -135,9 +139,9 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
for (auto &fp : factors) { for (auto &fp : factors) {
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) { if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
gfg.push_back(ptr->inner()); gfg.push_back(ptr->inner());
} else if (auto p = } else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
boost::static_pointer_cast<HybridConditional>(fp)->inner()) { gfg.push_back(
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p)); boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
} else { } else {
// It is an orphan wrapped conditional // It is an orphan wrapped conditional
} }
@ -153,12 +157,14 @@ std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
discreteElimination(const HybridGaussianFactorGraph &factors, discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto &fp : factors) {
if (auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp)) { for (auto &factor : factors) {
dfg.push_back(ptr->inner()); if (auto p = boost::dynamic_pointer_cast<HybridDiscreteFactor>(factor)) {
} else if (auto p = dfg.push_back(p->inner());
boost::static_pointer_cast<HybridConditional>(fp)->inner()) { } else if (auto p = boost::static_pointer_cast<HybridConditional>(factor)) {
dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p)); auto discrete_conditional =
boost::static_pointer_cast<DiscreteConditional>(p->inner());
dfg.push_back(discrete_conditional);
} else { } else {
// It is an orphan wrapper // It is an orphan wrapper
} }
@ -213,10 +219,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
result = EliminatePreferCholesky(graph, frontalKeys); result = EliminatePreferCholesky(graph, frontalKeys);
if (keysOfEliminated.empty()) { if (keysOfEliminated.empty()) {
keysOfEliminated = // Initialize the keysOfEliminated to be the keys of the
result.first->keys(); // Initialize the keysOfEliminated to be the // eliminated GaussianConditional
keysOfEliminated = result.first->keys();
} }
// keysOfEliminated of the GaussianConditional
if (keysOfSeparator.empty()) { if (keysOfSeparator.empty()) {
keysOfSeparator = result.second->keys(); keysOfSeparator = result.second->keys();
} }
@ -244,6 +250,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
return exp(-factor->error(empty_values)); return exp(-factor->error(empty_values));
}; };
DecisionTree<Key, double> fdt(separatorFactors, factorError); DecisionTree<Key, double> fdt(separatorFactors, factorError);
auto discreteFactor = auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt); boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
@ -401,4 +408,19 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor)); FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
} }
/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = discreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
return ordering;
}
} // namespace gtsam } // namespace gtsam

View File

@ -169,6 +169,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
Base::push_back(sharedFactor); Base::push_back(sharedFactor);
} }
} }
/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
*
* @return const Ordering
*/
const Ordering getHybridOrdering() const;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -14,9 +14,10 @@
* @date March 31, 2022 * @date March 31, 2022
* @author Fan Jiang * @author Fan Jiang
* @author Frank Dellaert * @author Frank Dellaert
* @author Richard Roberts * @author Varun Agrawal
*/ */
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h> #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianISAM.h> #include <gtsam/hybrid/HybridGaussianISAM.h>
@ -41,6 +42,7 @@ HybridGaussianISAM::HybridGaussianISAM(const HybridBayesTree& bayesTree)
void HybridGaussianISAM::updateInternal( void HybridGaussianISAM::updateInternal(
const HybridGaussianFactorGraph& newFactors, const HybridGaussianFactorGraph& newFactors,
HybridBayesTree::Cliques* orphans, HybridBayesTree::Cliques* orphans,
const boost::optional<size_t>& maxNrLeaves,
const boost::optional<Ordering>& ordering, const boost::optional<Ordering>& ordering,
const HybridBayesTree::Eliminate& function) { const HybridBayesTree::Eliminate& function) {
// Remove the contaminated part of the Bayes tree // Remove the contaminated part of the Bayes tree
@ -57,26 +59,28 @@ void HybridGaussianISAM::updateInternal(
factors += newFactors; factors += newFactors;
// Add the orphaned subtrees // Add the orphaned subtrees
for (const sharedClique& orphan : *orphans) for (const sharedClique& orphan : *orphans) {
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan); factors += boost::make_shared<BayesTreeOrphanWrapper<Node>>(orphan);
KeySet allDiscrete;
for (auto& factor : factors) {
for (auto& k : factor->discreteKeys()) {
allDiscrete.insert(k.first);
}
} }
// Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeys();
// Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast; KeyVector newKeysDiscreteLast;
// Insert continuous keys first.
for (auto& k : newFactorKeys) { for (auto& k : newFactorKeys) {
if (!allDiscrete.exists(k)) { if (!allDiscrete.exists(k)) {
newKeysDiscreteLast.push_back(k); newKeysDiscreteLast.push_back(k);
} }
} }
// Insert discrete keys at the end
std::copy(allDiscrete.begin(), allDiscrete.end(), std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast)); std::back_inserter(newKeysDiscreteLast));
// Get an ordering where the new keys are eliminated last // Get an ordering where the new keys are eliminated last
const VariableIndex index(factors); const VariableIndex index(factors);
Ordering elimination_ordering; Ordering elimination_ordering;
if (ordering) { if (ordering) {
elimination_ordering = *ordering; elimination_ordering = *ordering;
@ -91,6 +95,10 @@ void HybridGaussianISAM::updateInternal(
HybridBayesTree::shared_ptr bayesTree = HybridBayesTree::shared_ptr bayesTree =
factors.eliminateMultifrontal(elimination_ordering, function, index); factors.eliminateMultifrontal(elimination_ordering, function, index);
if (maxNrLeaves) {
bayesTree->prune(*maxNrLeaves);
}
// Re-add into Bayes tree data structures // Re-add into Bayes tree data structures
this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(), this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(),
bayesTree->roots().end()); bayesTree->roots().end());
@ -99,61 +107,11 @@ void HybridGaussianISAM::updateInternal(
/* ************************************************************************* */ /* ************************************************************************* */
void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors, void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
const boost::optional<size_t>& maxNrLeaves,
const boost::optional<Ordering>& ordering, const boost::optional<Ordering>& ordering,
const HybridBayesTree::Eliminate& function) { const HybridBayesTree::Eliminate& function) {
Cliques orphans; Cliques orphans;
this->updateInternal(newFactors, &orphans, ordering, function); this->updateInternal(newFactors, &orphans, maxNrLeaves, ordering, function);
}
/* ************************************************************************* */
/**
* @brief Check if `b` is a subset of `a`.
* Non-const since they need to be sorted.
*
* @param a KeyVector
* @param b KeyVector
* @return True if the keys of b is a subset of a, else false.
*/
bool IsSubset(KeyVector a, KeyVector b) {
std::sort(a.begin(), a.end());
std::sort(b.begin(), b.end());
return std::includes(a.begin(), a.end(), b.begin(), b.end());
}
/* ************************************************************************* */
void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->clique(root)->conditional()->inner());
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;
std::vector<gtsam::Key> prunedKeys;
for (auto&& clique : nodes()) {
// The cliques can be repeated for each frontal so we record it in
// prunedKeys and check if we have already pruned a particular clique.
if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) !=
prunedKeys.end()) {
continue;
}
// Add all the keys of the current clique to be pruned to prunedKeys
for (auto&& key : clique.second->conditional()->frontals()) {
prunedKeys.push_back(key);
}
// Convert parents() to a KeyVector for comparison
KeyVector parents;
for (auto&& parent : clique.second->conditional()->parents()) {
parents.push_back(parent);
}
if (IsSubset(parents, decisionTree->keys())) {
auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
clique.second->conditional()->inner());
gaussianMixture->prune(prunedDiscreteFactor);
}
}
} }
} // namespace gtsam } // namespace gtsam

View File

@ -53,6 +53,7 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
void updateInternal( void updateInternal(
const HybridGaussianFactorGraph& newFactors, const HybridGaussianFactorGraph& newFactors,
HybridBayesTree::Cliques* orphans, HybridBayesTree::Cliques* orphans,
const boost::optional<size_t>& maxNrLeaves = boost::none,
const boost::optional<Ordering>& ordering = boost::none, const boost::optional<Ordering>& ordering = boost::none,
const HybridBayesTree::Eliminate& function = const HybridBayesTree::Eliminate& function =
HybridBayesTree::EliminationTraitsType::DefaultEliminate); HybridBayesTree::EliminationTraitsType::DefaultEliminate);
@ -62,20 +63,15 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
* @brief Perform update step with new factors. * @brief Perform update step with new factors.
* *
* @param newFactors Factor graph of new factors to add and eliminate. * @param newFactors Factor graph of new factors to add and eliminate.
* @param maxNrLeaves The maximum number of leaves to keep after pruning.
* @param ordering Custom elimination ordering.
* @param function Elimination function. * @param function Elimination function.
*/ */
void update(const HybridGaussianFactorGraph& newFactors, void update(const HybridGaussianFactorGraph& newFactors,
const boost::optional<size_t>& maxNrLeaves = boost::none,
const boost::optional<Ordering>& ordering = boost::none, const boost::optional<Ordering>& ordering = boost::none,
const HybridBayesTree::Eliminate& function = const HybridBayesTree::Eliminate& function =
HybridBayesTree::EliminationTraitsType::DefaultEliminate); HybridBayesTree::EliminationTraitsType::DefaultEliminate);
/**
* @brief
*
* @param root The root key in the discrete conditional decision tree.
* @param maxNumberLeaves
*/
void prune(const Key& root, const size_t maxNumberLeaves);
}; };
/// traits /// traits

View File

@ -31,9 +31,7 @@ template class EliminatableClusterTree<HybridBayesTree,
template class JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>; template class JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>;
struct HybridConstructorTraversalData { struct HybridConstructorTraversalData {
typedef typedef HybridJunctionTree::Node Node;
typename JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>::Node
Node;
typedef typedef
typename JunctionTree<HybridBayesTree, typename JunctionTree<HybridBayesTree,
HybridGaussianFactorGraph>::sharedNode sharedNode; HybridGaussianFactorGraph>::sharedNode sharedNode;
@ -62,6 +60,7 @@ struct HybridConstructorTraversalData {
data.junctionTreeNode = boost::make_shared<Node>(node->key, node->factors); data.junctionTreeNode = boost::make_shared<Node>(node->key, node->factors);
parentData.junctionTreeNode->addChild(data.junctionTreeNode); parentData.junctionTreeNode->addChild(data.junctionTreeNode);
// Add all the discrete keys in the hybrid factors to the current data
for (HybridFactor::shared_ptr& f : node->factors) { for (HybridFactor::shared_ptr& f : node->factors) {
for (auto& k : f->discreteKeys()) { for (auto& k : f->discreteKeys()) {
data.discreteKeys.insert(k.first); data.discreteKeys.insert(k.first);
@ -72,8 +71,8 @@ struct HybridConstructorTraversalData {
} }
// Post-order visitor function // Post-order visitor function
static void ConstructorTraversalVisitorPostAlg2( static void ConstructorTraversalVisitorPost(
const boost::shared_ptr<HybridEliminationTree::Node>& ETreeNode, const boost::shared_ptr<HybridEliminationTree::Node>& node,
const HybridConstructorTraversalData& data) { const HybridConstructorTraversalData& data) {
// In this post-order visitor, we combine the symbolic elimination results // In this post-order visitor, we combine the symbolic elimination results
// from the elimination tree children and symbolically eliminate the current // from the elimination tree children and symbolically eliminate the current
@ -86,15 +85,15 @@ struct HybridConstructorTraversalData {
// Do symbolic elimination for this node // Do symbolic elimination for this node
SymbolicFactors symbolicFactors; SymbolicFactors symbolicFactors;
symbolicFactors.reserve(ETreeNode->factors.size() + symbolicFactors.reserve(node->factors.size() +
data.childSymbolicFactors.size()); data.childSymbolicFactors.size());
// Add ETree node factors // Add ETree node factors
symbolicFactors += ETreeNode->factors; symbolicFactors += node->factors;
// Add symbolic factors passed up from children // Add symbolic factors passed up from children
symbolicFactors += data.childSymbolicFactors; symbolicFactors += data.childSymbolicFactors;
Ordering keyAsOrdering; Ordering keyAsOrdering;
keyAsOrdering.push_back(ETreeNode->key); keyAsOrdering.push_back(node->key);
SymbolicConditional::shared_ptr conditional; SymbolicConditional::shared_ptr conditional;
SymbolicFactor::shared_ptr separatorFactor; SymbolicFactor::shared_ptr separatorFactor;
boost::tie(conditional, separatorFactor) = boost::tie(conditional, separatorFactor) =
@ -105,19 +104,19 @@ struct HybridConstructorTraversalData {
data.parentData->childSymbolicFactors.push_back(separatorFactor); data.parentData->childSymbolicFactors.push_back(separatorFactor);
data.parentData->discreteKeys.merge(data.discreteKeys); data.parentData->discreteKeys.merge(data.discreteKeys);
sharedNode node = data.junctionTreeNode; sharedNode jt_node = data.junctionTreeNode;
const FastVector<SymbolicConditional::shared_ptr>& childConditionals = const FastVector<SymbolicConditional::shared_ptr>& childConditionals =
data.childSymbolicConditionals; data.childSymbolicConditionals;
node->problemSize_ = (int)(conditional->size() * symbolicFactors.size()); jt_node->problemSize_ = (int)(conditional->size() * symbolicFactors.size());
// Merge our children if they are in our clique - if our conditional has // Merge our children if they are in our clique - if our conditional has
// exactly one fewer parent than our child's conditional. // exactly one fewer parent than our child's conditional.
const size_t nrParents = conditional->nrParents(); const size_t nrParents = conditional->nrParents();
const size_t nrChildren = node->nrChildren(); const size_t nrChildren = jt_node->nrChildren();
assert(childConditionals.size() == nrChildren); assert(childConditionals.size() == nrChildren);
// decide which children to merge, as index into children // decide which children to merge, as index into children
std::vector<size_t> nrChildrenFrontals = node->nrFrontalsOfChildren(); std::vector<size_t> nrChildrenFrontals = jt_node->nrFrontalsOfChildren();
std::vector<bool> merge(nrChildren, false); std::vector<bool> merge(nrChildren, false);
size_t nrFrontals = 1; size_t nrFrontals = 1;
for (size_t i = 0; i < nrChildren; i++) { for (size_t i = 0; i < nrChildren; i++) {
@ -137,7 +136,7 @@ struct HybridConstructorTraversalData {
} }
// now really merge // now really merge
node->mergeChildren(merge); jt_node->mergeChildren(merge);
} }
}; };
@ -161,7 +160,7 @@ HybridJunctionTree::HybridJunctionTree(
// the junction tree roots // the junction tree roots
treeTraversal::DepthFirstForest(eliminationTree, rootData, treeTraversal::DepthFirstForest(eliminationTree, rootData,
Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPre,
Data::ConstructorTraversalVisitorPostAlg2); Data::ConstructorTraversalVisitorPost);
// Assign roots from the dummy node // Assign roots from the dummy node
this->addChildrenAsRoots(rootData.junctionTreeNode); this->addChildrenAsRoots(rootData.junctionTreeNode);

View File

@ -1,76 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.cpp
* @date Aug, 2022
* @author Shangjie Xue
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/VectorValues.h>
#include <string>
#include <utility>
using std::pair;
using std::vector;
namespace gtsam {
/* ************************************************************************** */
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
// For discrete conditional, uses argmaxInPlace() method in
// DiscreteLookupTable.
if (isDiscrete()) {
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(
&(values->discrete));
} else if (isContinuous()) {
// For Gaussian conditional, uses solve() method in GaussianConditional.
values->continuous.insert(
boost::static_pointer_cast<GaussianConditional>(inner_)->solve(
values->continuous));
} else if (isHybrid()) {
// For hybrid conditional, since children should not contain discrete
// variable, we can condition on the discrete variable in the parents and
// solve the resulting GaussianConditional.
auto conditional =
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(
values->discrete);
values->continuous.insert(conditional->solve(values->continuous));
}
}
/* ************************************************************************** */
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
HybridLookupDAG dag;
for (auto&& conditional : bayesNet) {
HybridLookupTable hlt(*conditional);
dag.push_back(hlt);
}
return dag;
}
/* ************************************************************************** */
HybridValues HybridLookupDAG::argmax(HybridValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}
} // namespace gtsam

View File

@ -1,119 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridLookupDAG.h
* @date Aug, 2022
* @author Shangjie Xue
*/
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
/**
* @brief HybridLookupTable table for max-product
*
* Similar to DiscreteLookupTable, inherits from hybrid conditional for
* convenience. Is used in the max-product algorithm.
*/
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
public:
using Base = HybridConditional;
using This = HybridLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Hybrid Lookup Table object form a HybridConditional.
*
* @param conditional input hybrid conditional
*/
HybridLookupTable(HybridConditional& conditional) : Base(conditional){};
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(HybridValues* parentsValues) const;
};
/** A DAG made from hybrid lookup tables, as defined above. Similar to
* DiscreteLookupDAG */
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
public:
using Base = BayesNet<HybridLookupTable>;
using This = HybridLookupDAG;
using shared_ptr = boost::shared_ptr<This>;
/// @name Standard Constructors
/// @{
/// Construct empty DAG.
HybridLookupDAG() {}
/// Create from BayesNet with LookupTables
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet);
/// Destructor
virtual ~HybridLookupDAG() {}
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...);
}
/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
HybridValues argmax(HybridValues given = HybridValues()) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
// traits
template <>
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {};
} // namespace gtsam

View File

@ -27,8 +27,7 @@ void HybridNonlinearFactorGraph::add(
} }
/* ************************************************************************* */ /* ************************************************************************* */
void HybridNonlinearFactorGraph::add( void HybridNonlinearFactorGraph::add(boost::shared_ptr<DiscreteFactor> factor) {
boost::shared_ptr<DiscreteFactor> factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor)); FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
} }
@ -49,12 +48,12 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize( HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
const Values& continuousValues) const { const Values& continuousValues) const {
// create an empty linear FG // create an empty linear FG
HybridGaussianFactorGraph linearFG; auto linearFG = boost::make_shared<HybridGaussianFactorGraph>();
linearFG.reserve(size()); linearFG->reserve(size());
// linearize all hybrid factors // linearize all hybrid factors
for (auto&& factor : factors_) { for (auto&& factor : factors_) {
@ -66,9 +65,9 @@ HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize(
if (factor->isHybrid()) { if (factor->isHybrid()) {
// Check if it is a nonlinear mixture factor // Check if it is a nonlinear mixture factor
if (auto nlmf = boost::dynamic_pointer_cast<MixtureFactor>(factor)) { if (auto nlmf = boost::dynamic_pointer_cast<MixtureFactor>(factor)) {
linearFG.push_back(nlmf->linearize(continuousValues)); linearFG->push_back(nlmf->linearize(continuousValues));
} else { } else {
linearFG.push_back(factor); linearFG->push_back(factor);
} }
// Now check if the factor is a continuous only factor. // Now check if the factor is a continuous only factor.
@ -80,18 +79,18 @@ HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize(
boost::dynamic_pointer_cast<NonlinearFactor>(nlhf->inner())) { boost::dynamic_pointer_cast<NonlinearFactor>(nlhf->inner())) {
auto hgf = boost::make_shared<HybridGaussianFactor>( auto hgf = boost::make_shared<HybridGaussianFactor>(
nlf->linearize(continuousValues)); nlf->linearize(continuousValues));
linearFG.push_back(hgf); linearFG->push_back(hgf);
} else { } else {
linearFG.push_back(factor); linearFG->push_back(factor);
} }
// Finally if nothing else, we are discrete-only which doesn't need // Finally if nothing else, we are discrete-only which doesn't need
// lineariztion. // lineariztion.
} else { } else {
linearFG.push_back(factor); linearFG->push_back(factor);
} }
} else { } else {
linearFG.push_back(GaussianFactor::shared_ptr()); linearFG->push_back(GaussianFactor::shared_ptr());
} }
} }
return linearFG; return linearFG;

View File

@ -42,6 +42,16 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
using IsNonlinear = typename std::enable_if< using IsNonlinear = typename std::enable_if<
std::is_base_of<NonlinearFactor, FACTOR>::value>::type; std::is_base_of<NonlinearFactor, FACTOR>::value>::type;
/// Check if T has a value_type derived from FactorType.
template <typename T>
using HasDerivedValueType = typename std::enable_if<
std::is_base_of<HybridFactor, typename T::value_type>::value>::type;
/// Check if T has a pointer type derived from FactorType.
template <typename T>
using HasDerivedElementType = typename std::enable_if<std::is_base_of<
HybridFactor, typename T::value_type::element_type>::value>::type;
public: public:
using Base = HybridFactorGraph; using Base = HybridFactorGraph;
using This = HybridNonlinearFactorGraph; ///< this class using This = HybridNonlinearFactorGraph; ///< this class
@ -109,6 +119,21 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
} }
} }
/**
* Push back many factors as shared_ptr's in a container (factors are not
* copied)
*/
template <typename CONTAINER>
HasDerivedElementType<CONTAINER> push_back(const CONTAINER& container) {
Base::push_back(container.begin(), container.end());
}
/// Push back non-pointer objects in a container (factors are copied).
template <typename CONTAINER>
HasDerivedValueType<CONTAINER> push_back(const CONTAINER& container) {
Base::push_back(container.begin(), container.end());
}
/// Add a nonlinear factor as a shared ptr. /// Add a nonlinear factor as a shared ptr.
void add(boost::shared_ptr<NonlinearFactor> factor); void add(boost::shared_ptr<NonlinearFactor> factor);
@ -127,7 +152,8 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
* @param continuousValues: Dictionary of continuous values. * @param continuousValues: Dictionary of continuous values.
* @return HybridGaussianFactorGraph::shared_ptr * @return HybridGaussianFactorGraph::shared_ptr
*/ */
HybridGaussianFactorGraph linearize(const Values& continuousValues) const; HybridGaussianFactorGraph::shared_ptr linearize(
const Values& continuousValues) const;
}; };
template <> template <>

View File

@ -0,0 +1,114 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridNonlinearISAM.cpp
* @date Sep 12, 2022
* @author Varun Agrawal
*/
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearISAM.h>
#include <gtsam/inference/Ordering.h>
#include <iostream>
using namespace std;
namespace gtsam {
/* ************************************************************************* */
void HybridNonlinearISAM::saveGraph(const string& s,
const KeyFormatter& keyFormatter) const {
isam_.saveGraph(s, keyFormatter);
}
/* ************************************************************************* */
void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors,
const Values& initialValues,
const boost::optional<size_t>& maxNrLeaves,
const boost::optional<Ordering>& ordering) {
if (newFactors.size() > 0) {
// Reorder and relinearize every reorderInterval updates
if (reorderInterval_ > 0 && ++reorderCounter_ >= reorderInterval_) {
reorder_relinearize();
reorderCounter_ = 0;
}
factors_.push_back(newFactors);
// Linearize new factors and insert them
// TODO: optimize for whole config?
linPoint_.insert(initialValues);
boost::shared_ptr<HybridGaussianFactorGraph> linearizedNewFactors =
newFactors.linearize(linPoint_);
// Update ISAM
isam_.update(*linearizedNewFactors, maxNrLeaves, ordering,
eliminationFunction_);
}
}
/* ************************************************************************* */
void HybridNonlinearISAM::reorder_relinearize() {
if (factors_.size() > 0) {
// Obtain the new linearization point
const Values newLinPoint = estimate();
isam_.clear();
// Just recreate the whole BayesTree
// TODO: allow for constrained ordering here
// TODO: decouple relinearization and reordering to avoid
isam_.update(*factors_.linearize(newLinPoint), boost::none, boost::none,
eliminationFunction_);
// Update linearization point
linPoint_ = newLinPoint;
}
}
/* ************************************************************************* */
Values HybridNonlinearISAM::estimate() {
Values result;
if (isam_.size() > 0) {
HybridValues values = isam_.optimize();
assignment_ = values.discrete();
return linPoint_.retract(values.continuous());
} else {
return linPoint_;
}
}
// /* *************************************************************************
// */ Matrix HybridNonlinearISAM::marginalCovariance(Key key) const {
// return isam_.marginalCovariance(key);
// }
/* ************************************************************************* */
void HybridNonlinearISAM::print(const string& s,
const KeyFormatter& keyFormatter) const {
cout << s << "ReorderInterval: " << reorderInterval_
<< " Current Count: " << reorderCounter_ << endl;
isam_.print("HybridGaussianISAM:\n", keyFormatter);
linPoint_.print("Linearization Point:\n", keyFormatter);
factors_.print("Nonlinear Graph:\n", keyFormatter);
}
/* ************************************************************************* */
void HybridNonlinearISAM::printStats() const {
isam_.getCliqueData().getStats().print();
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -0,0 +1,131 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridNonlinearISAM.h
* @date Sep 12, 2022
* @author Varun Agrawal
*/
#pragma once
#include <gtsam/hybrid/HybridGaussianISAM.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
namespace gtsam {
/**
* Wrapper class to manage ISAM in a nonlinear context
*/
class GTSAM_EXPORT HybridNonlinearISAM {
protected:
/** The internal iSAM object */
gtsam::HybridGaussianISAM isam_;
/** The current linearization point */
Values linPoint_;
/// The discrete assignment
DiscreteValues assignment_;
/** The original factors, used when relinearizing */
HybridNonlinearFactorGraph factors_;
/** The reordering interval and counter */
int reorderInterval_;
int reorderCounter_;
/** The elimination function */
HybridGaussianFactorGraph::Eliminate eliminationFunction_;
public:
/// @name Standard Constructors
/// @{
/**
* Periodically reorder and relinearize
* @param reorderInterval is the number of updates between reorderings,
* 0 never reorders (and is dangerous for memory consumption)
* 1 (default) reorders every time, in worse case is batch every update
* typical values are 50 or 100
*/
HybridNonlinearISAM(
int reorderInterval = 1,
const HybridGaussianFactorGraph::Eliminate& eliminationFunction =
HybridGaussianFactorGraph::EliminationTraitsType::DefaultEliminate)
: reorderInterval_(reorderInterval),
reorderCounter_(0),
eliminationFunction_(eliminationFunction) {}
/// @}
/// @name Standard Interface
/// @{
/** Return the current solution estimate */
Values estimate();
// /** find the marginal covariance for a single variable */
// Matrix marginalCovariance(Key key) const;
// access
/** access the underlying bayes tree */
const HybridGaussianISAM& bayesTree() const { return isam_; }
/**
* @brief Prune the underlying Bayes tree.
*
* @param maxNumberLeaves The max number of leaf nodes to keep.
*/
void prune(const size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); }
/** Return the current linearization point */
const Values& getLinearizationPoint() const { return linPoint_; }
/** Return the current discrete assignment */
const DiscreteValues& getAssignment() const { return assignment_; }
/** get underlying nonlinear graph */
const HybridNonlinearFactorGraph& getFactorsUnsafe() const {
return factors_;
}
/** get counters */
int reorderInterval() const { return reorderInterval_; } ///< TODO: comment
int reorderCounter() const { return reorderCounter_; } ///< TODO: comment
/** prints out all contents of the system */
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/** prints out clique statistics */
void printStats() const;
/** saves the Tree to a text file in GraphViz format */
void saveGraph(const std::string& s,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
/// @name Advanced Interface
/// @{
/** Add new factors along with their initial linearization points */
void update(const HybridNonlinearFactorGraph& newFactors,
const Values& initialValues,
const boost::optional<size_t>& maxNrLeaves = boost::none,
const boost::optional<Ordering>& ordering = boost::none);
/** Relinearization and reordering of variables */
void reorder_relinearize();
/// @}
};
} // namespace gtsam

View File

@ -31,60 +31,78 @@
namespace gtsam { namespace gtsam {
/** /**
* HybridValues represents a collection of DiscreteValues and VectorValues. It * HybridValues represents a collection of DiscreteValues and VectorValues.
* is typically used to store the variables of a HybridGaussianFactorGraph. * It is typically used to store the variables of a HybridGaussianFactorGraph.
* Optimizing a HybridGaussianBayesNet returns this class. * Optimizing a HybridGaussianBayesNet returns this class.
*/ */
class GTSAM_EXPORT HybridValues { class GTSAM_EXPORT HybridValues {
public: private:
// DiscreteValue stored the discrete components of the HybridValues. // DiscreteValue stored the discrete components of the HybridValues.
DiscreteValues discrete; DiscreteValues discrete_;
// VectorValue stored the continuous components of the HybridValues. // VectorValue stored the continuous components of the HybridValues.
VectorValues continuous; VectorValues continuous_;
// Default constructor creates an empty HybridValues. public:
HybridValues() : discrete(), continuous(){}; /// @name Standard Constructors
/// @{
// Construct from DiscreteValues and VectorValues. /// Default constructor creates an empty HybridValues.
HybridValues() = default;
/// Construct from DiscreteValues and VectorValues.
HybridValues(const DiscreteValues& dv, const VectorValues& cv) HybridValues(const DiscreteValues& dv, const VectorValues& cv)
: discrete(dv), continuous(cv){}; : discrete_(dv), continuous_(cv){};
// print required by Testable for unit testing /// @}
/// @name Testable
/// @{
/// print required by Testable for unit testing
void print(const std::string& s = "HybridValues", void print(const std::string& s = "HybridValues",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": \n"; std::cout << s << ": \n";
discrete.print(" Discrete", keyFormatter); // print discrete components discrete_.print(" Discrete", keyFormatter); // print discrete components
continuous.print(" Continuous", continuous_.print(" Continuous",
keyFormatter); // print continuous components keyFormatter); // print continuous components
}; };
// equals required by Testable for unit testing /// equals required by Testable for unit testing
bool equals(const HybridValues& other, double tol = 1e-9) const { bool equals(const HybridValues& other, double tol = 1e-9) const {
return discrete.equals(other.discrete, tol) && return discrete_.equals(other.discrete_, tol) &&
continuous.equals(other.continuous, tol); continuous_.equals(other.continuous_, tol);
} }
// Check whether a variable with key \c j exists in DiscreteValue. /// @}
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); }; /// @name Interface
/// @{
// Check whether a variable with key \c j exists in VectorValue. /// Return the discrete MPE assignment
bool existsVector(Key j) { return continuous.exists(j); }; DiscreteValues discrete() const { return discrete_; }
// Check whether a variable with key \c j exists. /// Return the delta update for the continuous vectors
VectorValues continuous() const { return continuous_; }
/// Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
/// Check whether a variable with key \c j exists in VectorValue.
bool existsVector(Key j) { return continuous_.exists(j); };
/// Check whether a variable with key \c j exists.
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
/** Insert a discrete \c value with key \c j. Replaces the existing value if /** Insert a discrete \c value with key \c j. Replaces the existing value if
* the key \c j is already used. * the key \c j is already used.
* @param value The vector to be inserted. * @param value The vector to be inserted.
* @param j The index with which the value will be associated. */ * @param j The index with which the value will be associated. */
void insert(Key j, int value) { discrete[j] = value; }; void insert(Key j, int value) { discrete_[j] = value; };
/** Insert a vector \c value with key \c j. Throws an invalid_argument /** Insert a vector \c value with key \c j. Throws an invalid_argument
* exception if the key \c j is already used. * exception if the key \c j is already used.
* @param value The vector to be inserted. * @param value The vector to be inserted.
* @param j The index with which the value will be associated. */ * @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { continuous.insert(j, value); } void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h // TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
@ -92,13 +110,13 @@ class GTSAM_EXPORT HybridValues {
* Read/write access to the discrete value with key \c j, throws * Read/write access to the discrete value with key \c j, throws
* std::out_of_range if \c j does not exist. * std::out_of_range if \c j does not exist.
*/ */
size_t& atDiscrete(Key j) { return discrete.at(j); }; size_t& atDiscrete(Key j) { return discrete_.at(j); };
/** /**
* Read/write access to the vector value with key \c j, throws * Read/write access to the vector value with key \c j, throws
* std::out_of_range if \c j does not exist. * std::out_of_range if \c j does not exist.
*/ */
Vector& at(Key j) { return continuous.at(j); }; Vector& at(Key j) { return continuous_.at(j); };
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
@ -112,8 +130,8 @@ class GTSAM_EXPORT HybridValues {
std::string html( std::string html(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::stringstream ss; std::stringstream ss;
ss << this->discrete.html(keyFormatter); ss << this->discrete_.html(keyFormatter);
ss << this->continuous.html(keyFormatter); ss << this->continuous_.html(keyFormatter);
return ss.str(); return ss.str();
}; };

View File

@ -100,11 +100,23 @@ class MixtureFactor : public HybridFactor {
bool normalized = false) bool normalized = false)
: Base(keys, discreteKeys), normalized_(normalized) { : Base(keys, discreteKeys), normalized_(normalized) {
std::vector<NonlinearFactor::shared_ptr> nonlinear_factors; std::vector<NonlinearFactor::shared_ptr> nonlinear_factors;
KeySet continuous_keys_set(keys.begin(), keys.end());
KeySet factor_keys_set;
for (auto&& f : factors) { for (auto&& f : factors) {
// Insert all factor continuous keys in the continuous keys set.
std::copy(f->keys().begin(), f->keys().end(),
std::inserter(factor_keys_set, factor_keys_set.end()));
nonlinear_factors.push_back( nonlinear_factors.push_back(
boost::dynamic_pointer_cast<NonlinearFactor>(f)); boost::dynamic_pointer_cast<NonlinearFactor>(f));
} }
factors_ = Factors(discreteKeys, nonlinear_factors); factors_ = Factors(discreteKeys, nonlinear_factors);
if (continuous_keys_set != factor_keys_set) {
throw std::runtime_error(
"The specified continuous keys and the keys in the factors don't "
"match!");
}
} }
~MixtureFactor() = default; ~MixtureFactor() = default;

View File

@ -6,8 +6,8 @@ namespace gtsam {
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
class HybridValues { class HybridValues {
gtsam::DiscreteValues discrete; gtsam::DiscreteValues discrete() const;
gtsam::VectorValues continuous; gtsam::VectorValues continuous() const;
HybridValues(); HybridValues();
HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv); HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv);
void print(string s = "HybridValues", void print(string s = "HybridValues",
@ -99,6 +99,8 @@ class HybridBayesTree {
bool empty() const; bool empty() const;
const HybridBayesTreeClique* operator[](size_t j) const; const HybridBayesTreeClique* operator[](size_t j) const;
gtsam::HybridValues optimize() const;
string dot(const gtsam::KeyFormatter& keyFormatter = string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
}; };

View File

@ -115,7 +115,6 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
/* *************************************************************************** /* ***************************************************************************
*/ */
using MotionModel = BetweenFactor<double>; using MotionModel = BetweenFactor<double>;
// using MotionMixture = MixtureFactor<MotionModel>;
// Test fixture with switching network. // Test fixture with switching network.
struct Switching { struct Switching {
@ -125,12 +124,15 @@ struct Switching {
HybridGaussianFactorGraph linearizedFactorGraph; HybridGaussianFactorGraph linearizedFactorGraph;
Values linearizationPoint; Values linearizationPoint;
/// Create with given number of time steps. /**
* @brief Create with given number of time steps.
*
* @param K The total number of timesteps.
* @param between_sigma The stddev between poses.
* @param prior_sigma The stddev on priors (also used for measurements).
*/
Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1) Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1)
: K(K) { : K(K) {
using symbol_shorthand::M;
using symbol_shorthand::X;
// Create DiscreteKeys for binary K modes, modes[0] will not be used. // Create DiscreteKeys for binary K modes, modes[0] will not be used.
for (size_t k = 0; k <= K; k++) { for (size_t k = 0; k <= K; k++) {
modes.emplace_back(M(k), 2); modes.emplace_back(M(k), 2);
@ -145,7 +147,7 @@ struct Switching {
// Add "motion models". // Add "motion models".
for (size_t k = 1; k < K; k++) { for (size_t k = 1; k < K; k++) {
KeyVector keys = {X(k), X(k + 1)}; KeyVector keys = {X(k), X(k + 1)};
auto motion_models = motionModels(k); auto motion_models = motionModels(k, between_sigma);
std::vector<NonlinearFactor::shared_ptr> components; std::vector<NonlinearFactor::shared_ptr> components;
for (auto &&f : motion_models) { for (auto &&f : motion_models) {
components.push_back(boost::dynamic_pointer_cast<NonlinearFactor>(f)); components.push_back(boost::dynamic_pointer_cast<NonlinearFactor>(f));
@ -155,7 +157,7 @@ struct Switching {
} }
// Add measurement factors // Add measurement factors
auto measurement_noise = noiseModel::Isotropic::Sigma(1, 0.1); auto measurement_noise = noiseModel::Isotropic::Sigma(1, prior_sigma);
for (size_t k = 2; k <= K; k++) { for (size_t k = 2; k <= K; k++) {
nonlinearFactorGraph.emplace_nonlinear<PriorFactor<double>>( nonlinearFactorGraph.emplace_nonlinear<PriorFactor<double>>(
X(k), 1.0 * (k - 1), measurement_noise); X(k), 1.0 * (k - 1), measurement_noise);
@ -169,15 +171,14 @@ struct Switching {
linearizationPoint.insert<double>(X(k), static_cast<double>(k)); linearizationPoint.insert<double>(X(k), static_cast<double>(k));
} }
linearizedFactorGraph = nonlinearFactorGraph.linearize(linearizationPoint); // The ground truth is robot moving forward
// and one less than the linearization point
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
} }
// Create motion models for a given time step // Create motion models for a given time step
static std::vector<MotionModel::shared_ptr> motionModels(size_t k, static std::vector<MotionModel::shared_ptr> motionModels(size_t k,
double sigma = 1.0) { double sigma = 1.0) {
using symbol_shorthand::M;
using symbol_shorthand::X;
auto noise_model = noiseModel::Isotropic::Sigma(1, sigma); auto noise_model = noiseModel::Isotropic::Sigma(1, sigma);
auto still = auto still =
boost::make_shared<MotionModel>(X(k), X(k + 1), 0.0, noise_model), boost::make_shared<MotionModel>(X(k), X(k + 1), 0.0, noise_model),

View File

@ -18,7 +18,10 @@
* @date December 2021 * @date December 2021
*/ */
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include "Switching.h" #include "Switching.h"
@ -27,6 +30,8 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace gtsam::serializationTestHelpers;
using noiseModel::Isotropic; using noiseModel::Isotropic;
using symbol_shorthand::M; using symbol_shorthand::M;
using symbol_shorthand::X; using symbol_shorthand::X;
@ -47,6 +52,20 @@ TEST(HybridBayesNet, Creation) {
EXPECT(df.equals(expected)); EXPECT(df.equals(expected));
} }
/* ****************************************************************************/
// Test adding a bayes net to another one.
TEST(HybridBayesNet, Add) {
HybridBayesNet bayesNet;
bayesNet.add(Asia, "99/1");
DiscreteConditional expected(Asia, "99/1");
HybridBayesNet other;
other.push_back(bayesNet);
EXPECT(bayesNet.equals(other));
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test choosing an assignment of conditionals // Test choosing an assignment of conditionals
TEST(HybridBayesNet, Choose) { TEST(HybridBayesNet, Choose) {
@ -72,19 +91,128 @@ TEST(HybridBayesNet, Choose) {
EXPECT_LONGS_EQUAL(4, gbn.size()); EXPECT_LONGS_EQUAL(4, gbn.size());
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>( EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(0)))(assignment), hybridBayesNet->atMixture(0)))(assignment),
*gbn.at(0))); *gbn.at(0)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>( EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(1)))(assignment), hybridBayesNet->atMixture(1)))(assignment),
*gbn.at(1))); *gbn.at(1)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>( EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(2)))(assignment), hybridBayesNet->atMixture(2)))(assignment),
*gbn.at(2))); *gbn.at(2)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>( EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(3)))(assignment), hybridBayesNet->atMixture(3)))(assignment),
*gbn.at(3))); *gbn.at(3)));
} }
/* ****************************************************************************/
// Test bayes net optimize
TEST(HybridBayesNet, OptimizeAssignment) {
Switching s(4);
Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}
HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteValues assignment;
assignment[M(1)] = 1;
assignment[M(2)] = 1;
assignment[M(3)] = 1;
VectorValues delta = hybridBayesNet->optimize(assignment);
// The linearization point has the same value as the key index,
// e.g. X(1) = 1, X(2) = 2,
// but the factors specify X(k) = k-1, so delta should be -1.
VectorValues expected_delta;
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
expected_delta.insert(make_pair(X(4), -Vector1::Ones()));
EXPECT(assert_equal(expected_delta, delta));
}
/* ****************************************************************************/
// Test bayes net optimize
TEST(HybridBayesNet, Optimize) {
Switching s(4);
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
DiscreteValues expectedAssignment;
expectedAssignment[M(1)] = 1;
expectedAssignment[M(2)] = 0;
expectedAssignment[M(3)] = 1;
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
VectorValues expectedValues;
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}
/* ****************************************************************************/
// Test bayes net multifrontal optimize
TEST(HybridBayesNet, OptimizeMultifrontal) {
Switching s(4);
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree::shared_ptr hybridBayesTree =
s.linearizedFactorGraph.eliminateMultifrontal(hybridOrdering);
HybridValues delta = hybridBayesTree->optimize();
VectorValues expectedValues;
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}
/* ****************************************************************************/
// Test bayes net pruning
TEST(HybridBayesNet, Prune) {
Switching s(4);
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
auto prunedBayesNet = hybridBayesNet->prune(2);
HybridValues pruned_delta = prunedBayesNet.optimize();
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
}
/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -0,0 +1,166 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridBayesTree.cpp
* @brief Unit tests for HybridBayesTree
* @author Varun Agrawal
* @date August 2022
*/
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
#include "Switching.h"
// Include for test suite
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
/* ****************************************************************************/
// Test for optimizing a HybridBayesTree with a given assignment.
TEST(HybridBayesTree, OptimizeAssignment) {
Switching s(4);
HybridGaussianISAM isam;
HybridGaussianFactorGraph graph1;
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
for (size_t i = 1; i < 4; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the Gaussian factors, 1 prior on X(1),
// 3 measurements on X(2), X(3), X(4)
graph1.push_back(s.linearizedFactorGraph.at(0));
for (size_t i = 4; i <= 7; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the discrete factors
for (size_t i = 7; i <= 9; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
isam.update(graph1);
DiscreteValues assignment;
assignment[M(1)] = 1;
assignment[M(2)] = 1;
assignment[M(3)] = 1;
VectorValues delta = isam.optimize(assignment);
// The linearization point has the same value as the key index,
// e.g. X(1) = 1, X(2) = 2,
// but the factors specify X(k) = k-1, so delta should be -1.
VectorValues expected_delta;
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
expected_delta.insert(make_pair(X(4), -Vector1::Ones()));
EXPECT(assert_equal(expected_delta, delta));
// Create ordering.
Ordering ordering;
for (size_t k = 1; k <= s.K; k++) ordering += X(k);
HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
GaussianBayesNet gbn = hybridBayesNet->choose(assignment);
VectorValues expected = gbn.optimize();
EXPECT(assert_equal(expected, delta));
}
/* ****************************************************************************/
// Test for optimizing a HybridBayesTree.
TEST(HybridBayesTree, Optimize) {
Switching s(4);
HybridGaussianISAM isam;
HybridGaussianFactorGraph graph1;
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
for (size_t i = 1; i < 4; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the Gaussian factors, 1 prior on X(1),
// 3 measurements on X(2), X(3), X(4)
graph1.push_back(s.linearizedFactorGraph.at(0));
for (size_t i = 4; i <= 6; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the discrete factors
for (size_t i = 7; i <= 9; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
isam.update(graph1);
HybridValues delta = isam.optimize();
// Create ordering.
Ordering ordering;
for (size_t k = 1; k <= s.K; k++) ordering += X(k);
HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) {
auto factor = dynamic_pointer_cast<HybridDiscreteFactor>(f);
dfg.push_back(
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
}
DiscreteValues expectedMPE = dfg.optimize();
VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE);
EXPECT(assert_equal(expectedMPE, delta.discrete()));
EXPECT(assert_equal(expectedValues, delta.continuous()));
}
/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
using namespace gtsam::serializationTestHelpers;
EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
EXPECT(equalsBinary<HybridBayesTree>(hbt));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -184,8 +184,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
hfg.add(DecisionTreeFactor(m1, {2, 8})); hfg.add(DecisionTreeFactor(m1, {2, 8}));
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal( HybridBayesTree::shared_ptr result =
Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)})); hfg.eliminateMultifrontal(hfg.getHybridOrdering());
// The bayes tree should have 3 cliques // The bayes tree should have 3 cliques
EXPECT_LONGS_EQUAL(3, result->size()); EXPECT_LONGS_EQUAL(3, result->size());
@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8}))); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
// Get a constrained ordering keeping c1 last // Get a constrained ordering keeping c1 last
auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(1)}); auto ordering_full = hfg.getHybridOrdering();
// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1) // Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full); HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
@ -484,8 +484,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
} }
HybridBayesNet::shared_ptr hbn; HybridBayesNet::shared_ptr hbn;
HybridGaussianFactorGraph::shared_ptr remaining; HybridGaussianFactorGraph::shared_ptr remaining;
std::tie(hbn, remaining) = std::tie(hbn, remaining) = hfg->eliminatePartialSequential(ordering_partial);
hfg->eliminatePartialSequential(ordering_partial);
EXPECT_LONGS_EQUAL(14, hbn->size()); EXPECT_LONGS_EQUAL(14, hbn->size());
EXPECT_LONGS_EQUAL(11, remaining->size()); EXPECT_LONGS_EQUAL(11, remaining->size());
@ -501,6 +500,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
} }
} }
/* ************************************************************************* */
TEST(HybridGaussianFactorGraph, optimize) { TEST(HybridGaussianFactorGraph, optimize) {
HybridGaussianFactorGraph hfg; HybridGaussianFactorGraph hfg;
@ -522,6 +522,46 @@ TEST(HybridGaussianFactorGraph, optimize) {
EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
} }
/* ************************************************************************* */
// Test adding of gaussian conditional and re-elimination.
TEST(HybridGaussianFactorGraph, Conditionals) {
Switching switching(4);
HybridGaussianFactorGraph hfg;
hfg.push_back(switching.linearizedFactorGraph.at(0)); // P(X1)
Ordering ordering;
ordering.push_back(X(1));
HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering);
hfg.push_back(switching.linearizedFactorGraph.at(1)); // P(X1, X2 | M1)
hfg.push_back(*bayes_net);
hfg.push_back(switching.linearizedFactorGraph.at(2)); // P(X2, X3 | M2)
hfg.push_back(switching.linearizedFactorGraph.at(5)); // P(M1)
ordering.push_back(X(2));
ordering.push_back(X(3));
ordering.push_back(M(1));
ordering.push_back(M(2));
bayes_net = hfg.eliminateSequential(ordering);
HybridValues result = bayes_net->optimize();
Values expected_continuous;
expected_continuous.insert<double>(X(1), 0);
expected_continuous.insert<double>(X(2), 1);
expected_continuous.insert<double>(X(3), 2);
expected_continuous.insert<double>(X(4), 4);
Values result_continuous =
switching.linearizationPoint.retract(result.continuous());
EXPECT(assert_equal(expected_continuous, result_continuous));
DiscreteValues expected_discrete;
expected_discrete[M(1)] = 1;
expected_discrete[M(2)] = 1;
EXPECT(assert_equal(expected_discrete, result.discrete()));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -235,7 +235,7 @@ TEST(HybridGaussianElimination, Approx_inference) {
size_t maxNrLeaves = 5; size_t maxNrLeaves = 5;
incrementalHybrid.update(graph1); incrementalHybrid.update(graph1);
incrementalHybrid.prune(M(3), maxNrLeaves); incrementalHybrid.prune(maxNrLeaves);
/* /*
unpruned factor is: unpruned factor is:
@ -329,7 +329,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
// Run update with pruning // Run update with pruning
size_t maxComponents = 5; size_t maxComponents = 5;
incrementalHybrid.update(graph1); incrementalHybrid.update(graph1);
incrementalHybrid.prune(M(3), maxComponents); incrementalHybrid.prune(maxComponents);
// Check if we have a bayes tree with 4 hybrid nodes, // Check if we have a bayes tree with 4 hybrid nodes,
// each with 2, 4, 8, and 5 (pruned) leaves respetively. // each with 2, 4, 8, and 5 (pruned) leaves respetively.
@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents()); 2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); 3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents()); 5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
// Run update with pruning a second time. // Run update with pruning a second time.
incrementalHybrid.update(graph2); incrementalHybrid.update(graph2);
incrementalHybrid.prune(M(4), maxComponents); incrementalHybrid.prune(maxComponents);
// Check if we have a bayes tree with pruned hybrid nodes, // Check if we have a bayes tree with pruned hybrid nodes,
// with 5 (pruned) leaves. // with 5 (pruned) leaves.
@ -399,7 +399,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
initial.insert(Z(0), Pose2(0.0, 2.0, 0.0)); initial.insert(Z(0), Pose2(0.0, 2.0, 0.0));
initial.insert(W(0), Pose2(0.0, 3.0, 0.0)); initial.insert(W(0), Pose2(0.0, 3.0, 0.0));
HybridGaussianFactorGraph gfg = fg.linearize(initial); HybridGaussianFactorGraph gfg = *fg.linearize(initial);
fg = HybridNonlinearFactorGraph(); fg = HybridNonlinearFactorGraph();
HybridGaussianISAM inc; HybridGaussianISAM inc;
@ -444,7 +444,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
// The leg link did not move so we set the expected pose accordingly. // The leg link did not move so we set the expected pose accordingly.
initial.insert(W(1), Pose2(0.0, 3.0, 0.0)); initial.insert(W(1), Pose2(0.0, 3.0, 0.0));
gfg = fg.linearize(initial); gfg = *fg.linearize(initial);
fg = HybridNonlinearFactorGraph(); fg = HybridNonlinearFactorGraph();
// Update without pruning // Update without pruning
@ -483,7 +483,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
initial.insert(Z(2), Pose2(2.0, 2.0, 0.0)); initial.insert(Z(2), Pose2(2.0, 2.0, 0.0));
initial.insert(W(2), Pose2(0.0, 3.0, 0.0)); initial.insert(W(2), Pose2(0.0, 3.0, 0.0));
gfg = fg.linearize(initial); gfg = *fg.linearize(initial);
fg = HybridNonlinearFactorGraph(); fg = HybridNonlinearFactorGraph();
// Now we prune! // Now we prune!
@ -496,7 +496,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
// The MHS at this point should be a 2 level tree on (1, 2). // The MHS at this point should be a 2 level tree on (1, 2).
// 1 has 2 choices, and 2 has 4 choices. // 1 has 2 choices, and 2 has 4 choices.
inc.update(gfg); inc.update(gfg);
inc.prune(M(2), 2); inc.prune(2);
/*************** Run Round 4 ***************/ /*************** Run Round 4 ***************/
// Add odometry factor with discrete modes. // Add odometry factor with discrete modes.
@ -526,12 +526,12 @@ TEST(HybridGaussianISAM, NonTrivial) {
initial.insert(Z(3), Pose2(3.0, 2.0, 0.0)); initial.insert(Z(3), Pose2(3.0, 2.0, 0.0));
initial.insert(W(3), Pose2(0.0, 3.0, 0.0)); initial.insert(W(3), Pose2(0.0, 3.0, 0.0));
gfg = fg.linearize(initial); gfg = *fg.linearize(initial);
fg = HybridNonlinearFactorGraph(); fg = HybridNonlinearFactorGraph();
// Keep pruning! // Keep pruning!
inc.update(gfg); inc.update(gfg);
inc.prune(M(3), 3); inc.prune(3);
// The final discrete graph should not be empty since we have eliminated // The final discrete graph should not be empty since we have eliminated
// all continuous variables. // all continuous variables.

View File

@ -1,272 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridLookupDAG.cpp
* @date Aug, 2022
* @author Shangjie Xue
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>
// Include for test suite
#include <CppUnitLite/TestHarness.h>
#include <iostream>
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
TEST(HybridLookupTable, basics) {
// create a conditional gaussian node
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 2;
S1(0, 1) = 3;
S1(1, 1) = 4;
Matrix S2(2, 2);
S2(0, 0) = 6;
S2(1, 0) = 0.2;
S2(0, 1) = 8;
S2(1, 1) = 0.4;
Matrix R1(2, 2);
R1(0, 0) = 0.1;
R1(1, 0) = 0.3;
R1(0, 1) = 0.0;
R1(1, 1) = 0.34;
Matrix R2(2, 2);
R2(0, 0) = 0.1;
R2(1, 0) = 0.3;
R2(0, 1) = 0.0;
R2(1, 1) = 0.34;
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
Vector2 d1(0.2, 0.5), d2(0.5, 0.2);
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
X(2), S1, model),
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
X(2), S2, model);
// Create decision tree
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals);
boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals));
HybridConditional hc(mixtureFactor);
GaussianMixture::Conditionals conditional2 =
boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals();
DiscreteValues dv;
dv[1] = 1;
VectorValues cv;
cv.insert(X(2), Vector2(0.0, 0.0));
HybridValues hv(dv, cv);
// std::cout << conditional2(values).markdown();
EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6));
EXPECT(conditional2(dv) == conditionals(dv));
HybridLookupTable hlt(hc);
// hlt.argmaxInPlace(&hv);
HybridLookupDAG dag;
dag.push_back(hlt);
dag.argmax(hv);
// HybridBayesNet hbn;
// hbn.push_back(hc);
// hbn.optimize();
}
TEST(HybridLookupTable, hybrid_argmax) {
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 0;
S1(0, 1) = 0;
S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 =
boost::make_shared<GaussianConditional>(X(1), d1, S1, model),
conditional1 =
boost::make_shared<GaussianConditional>(X(1), d2, S1, model);
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(1)}, {}, {m1}, conditionals));
HybridConditional hc(mixtureFactor);
DiscreteValues dv;
dv[1] = 1;
VectorValues cv;
// cv.insert(X(2),Vector2(0.0, 0.0));
HybridValues hv(dv, cv);
HybridLookupTable hlt(hc);
hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.at(X(1)), d2));
}
TEST(HybridLookupTable, discrete_argmax) {
DiscreteKey X(0, 2), Y(1, 2);
auto conditional = boost::make_shared<DiscreteConditional>(X | Y = "0/1 3/2");
HybridConditional hc(conditional);
HybridLookupTable hlt(hc);
DiscreteValues dv;
dv[1] = 0;
VectorValues cv;
// cv.insert(X(2),Vector2(0.0, 0.0));
HybridValues hv(dv, cv);
hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.atDiscrete(0), 1));
DecisionTreeFactor f1(X, "2 3");
auto conditional2 = boost::make_shared<DiscreteConditional>(1, f1);
HybridConditional hc2(conditional2);
HybridLookupTable hlt2(hc2);
HybridValues hv2;
hlt2.argmaxInPlace(&hv2);
EXPECT(assert_equal(hv2.atDiscrete(0), 1));
}
TEST(HybridLookupTable, gaussian_argmax) {
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 0;
S1(0, 1) = 0;
S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional =
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
HybridConditional hc(conditional);
HybridLookupTable hlt(hc);
DiscreteValues dv;
// dv[1]=0;
VectorValues cv;
cv.insert(X(2), d2);
HybridValues hv(dv, cv);
hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.at(X(1)), d1 + d2));
}
TEST(HybridLookupDAG, argmax) {
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 0;
S1(0, 1) = 0;
S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 =
boost::make_shared<GaussianConditional>(X(2), d1, S1, model),
conditional1 =
boost::make_shared<GaussianConditional>(X(2), d2, S1, model);
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(2)}, {}, {m1}, conditionals));
HybridConditional hc2(mixtureFactor);
HybridLookupTable hlt2(hc2);
auto conditional2 =
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
HybridConditional hc1(conditional2);
HybridLookupTable hlt1(hc1);
DecisionTreeFactor f1(m1, "2 3");
auto discrete_conditional = boost::make_shared<DiscreteConditional>(1, f1);
HybridConditional hc3(discrete_conditional);
HybridLookupTable hlt3(hc3);
HybridLookupDAG dag;
dag.push_back(hlt1);
dag.push_back(hlt2);
dag.push_back(hlt3);
auto hv = dag.argmax();
EXPECT(assert_equal(hv.atDiscrete(1), 1));
EXPECT(assert_equal(hv.at(X(2)), d2));
EXPECT(assert_equal(hv.at(X(1)), d2 + d1));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -60,7 +60,7 @@ TEST(HybridFactorGraph, GaussianFactorGraph) {
Values linearizationPoint; Values linearizationPoint;
linearizationPoint.insert<double>(X(0), 0); linearizationPoint.insert<double>(X(0), 0);
HybridGaussianFactorGraph ghfg = fg.linearize(linearizationPoint); HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint);
// Add a factor to the GaussianFactorGraph // Add a factor to the GaussianFactorGraph
ghfg.add(JacobianFactor(X(0), I_1x1, Vector1(5))); ghfg.add(JacobianFactor(X(0), I_1x1, Vector1(5)));
@ -139,7 +139,7 @@ TEST(HybridGaussianFactorGraph, Resize) {
linearizationPoint.insert<double>(X(1), 1); linearizationPoint.insert<double>(X(1), 1);
// Generate `HybridGaussianFactorGraph` by linearizing // Generate `HybridGaussianFactorGraph` by linearizing
HybridGaussianFactorGraph gfg = nhfg.linearize(linearizationPoint); HybridGaussianFactorGraph gfg = *nhfg.linearize(linearizationPoint);
EXPECT_LONGS_EQUAL(gfg.size(), 3); EXPECT_LONGS_EQUAL(gfg.size(), 3);
@ -147,6 +147,32 @@ TEST(HybridGaussianFactorGraph, Resize) {
EXPECT_LONGS_EQUAL(gfg.size(), 0); EXPECT_LONGS_EQUAL(gfg.size(), 0);
} }
/***************************************************************************
* Test that the MixtureFactor reports correctly if the number of continuous
* keys provided do not match the keys in the factors.
*/
TEST(HybridGaussianFactorGraph, MixtureFactor) {
auto nonlinearFactor = boost::make_shared<BetweenFactor<double>>(
X(0), X(1), 0.0, Isotropic::Sigma(1, 0.1));
auto discreteFactor = boost::make_shared<DecisionTreeFactor>();
auto noise_model = noiseModel::Isotropic::Sigma(1, 1.0);
auto still = boost::make_shared<MotionModel>(X(0), X(1), 0.0, noise_model),
moving = boost::make_shared<MotionModel>(X(0), X(1), 1.0, noise_model);
std::vector<MotionModel::shared_ptr> components = {still, moving};
// Check for exception when number of continuous keys are under-specified.
KeyVector contKeys = {X(0)};
THROWS_EXCEPTION(boost::make_shared<MixtureFactor>(
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components));
// Check for exception when number of continuous keys are too many.
contKeys = {X(0), X(1), X(2)};
THROWS_EXCEPTION(boost::make_shared<MixtureFactor>(
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components));
}
/***************************************************************************** /*****************************************************************************
* Test push_back on HFG makes the correct distinction. * Test push_back on HFG makes the correct distinction.
*/ */
@ -224,7 +250,7 @@ TEST(HybridFactorGraph, Linearization) {
// Linearize here: // Linearize here:
HybridGaussianFactorGraph actualLinearized = HybridGaussianFactorGraph actualLinearized =
self.nonlinearFactorGraph.linearize(self.linearizationPoint); *self.nonlinearFactorGraph.linearize(self.linearizationPoint);
EXPECT_LONGS_EQUAL(7, actualLinearized.size()); EXPECT_LONGS_EQUAL(7, actualLinearized.size());
} }
@ -257,14 +283,6 @@ TEST(GaussianElimination, Eliminate_x1) {
// Add first hybrid factor // Add first hybrid factor
factors.push_back(self.linearizedFactorGraph[1]); factors.push_back(self.linearizedFactorGraph[1]);
// TODO(Varun) remove this block since sum is no longer exposed.
// // Check that sum works:
// auto sum = factors.sum();
// Assignment<Key> mode;
// mode[M(1)] = 1;
// auto actual = sum(mode); // Selects one of 2 modes.
// EXPECT_LONGS_EQUAL(2, actual.size()); // Prior and motion model.
// Eliminate x1 // Eliminate x1
Ordering ordering; Ordering ordering;
ordering += X(1); ordering += X(1);
@ -289,15 +307,6 @@ TEST(HybridsGaussianElimination, Eliminate_x2) {
factors.push_back(self.linearizedFactorGraph[1]); // involves m1 factors.push_back(self.linearizedFactorGraph[1]); // involves m1
factors.push_back(self.linearizedFactorGraph[2]); // involves m2 factors.push_back(self.linearizedFactorGraph[2]); // involves m2
// TODO(Varun) remove this block since sum is no longer exposed.
// // Check that sum works:
// auto sum = factors.sum();
// Assignment<Key> mode;
// mode[M(1)] = 0;
// mode[M(2)] = 1;
// auto actual = sum(mode); // Selects one of 4 mode
// combinations. EXPECT_LONGS_EQUAL(2, actual.size()); // 2 motion models.
// Eliminate x2 // Eliminate x2
Ordering ordering; Ordering ordering;
ordering += X(2); ordering += X(2);
@ -364,51 +373,10 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
CHECK(discreteFactor); CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false); EXPECT(discreteFactor->root_->isLeaf() == false);
// TODO(Varun) Test emplace_discrete
} }
// /*
// ****************************************************************************/
// /// Test the toDecisionTreeFactor method
// TEST(HybridFactorGraph, ToDecisionTreeFactor) {
// size_t K = 3;
// // Provide tight sigma values so that the errors are visibly different.
// double between_sigma = 5e-8, prior_sigma = 1e-7;
// Switching self(K, between_sigma, prior_sigma);
// // Clear out discrete factors since sum() cannot hanldle those
// HybridGaussianFactorGraph linearizedFactorGraph(
// self.linearizedFactorGraph.gaussianGraph(), DiscreteFactorGraph(),
// self.linearizedFactorGraph.dcGraph());
// auto decisionTreeFactor = linearizedFactorGraph.toDecisionTreeFactor();
// auto allAssignments =
// DiscreteValues::CartesianProduct(linearizedFactorGraph.discreteKeys());
// // Get the error of the discrete assignment m1=0, m2=1.
// double actual = (*decisionTreeFactor)(allAssignments[1]);
// /********************************************/
// // Create equivalent factor graph for m1=0, m2=1
// GaussianFactorGraph graph = linearizedFactorGraph.gaussianGraph();
// for (auto &p : linearizedFactorGraph.dcGraph()) {
// if (auto mixture =
// boost::dynamic_pointer_cast<DCGaussianMixtureFactor>(p)) {
// graph.add((*mixture)(allAssignments[1]));
// }
// }
// VectorValues values = graph.optimize();
// double expected = graph.probPrime(values);
// /********************************************/
// EXPECT_DOUBLES_EQUAL(expected, actual, 1e-12);
// // REGRESSION:
// EXPECT_DOUBLES_EQUAL(0.6125, actual, 1e-4);
// }
/**************************************************************************** /****************************************************************************
* Test partial elimination * Test partial elimination
*/ */
@ -428,7 +396,6 @@ TEST(HybridFactorGraph, Partial_Elimination) {
linearizedFactorGraph.eliminatePartialSequential(ordering); linearizedFactorGraph.eliminatePartialSequential(ordering);
CHECK(hybridBayesNet); CHECK(hybridBayesNet);
// GTSAM_PRINT(*hybridBayesNet); // HybridBayesNet
EXPECT_LONGS_EQUAL(3, hybridBayesNet->size()); EXPECT_LONGS_EQUAL(3, hybridBayesNet->size());
EXPECT(hybridBayesNet->at(0)->frontals() == KeyVector{X(1)}); EXPECT(hybridBayesNet->at(0)->frontals() == KeyVector{X(1)});
EXPECT(hybridBayesNet->at(0)->parents() == KeyVector({X(2), M(1)})); EXPECT(hybridBayesNet->at(0)->parents() == KeyVector({X(2), M(1)}));
@ -438,7 +405,6 @@ TEST(HybridFactorGraph, Partial_Elimination) {
EXPECT(hybridBayesNet->at(2)->parents() == KeyVector({M(1), M(2)})); EXPECT(hybridBayesNet->at(2)->parents() == KeyVector({M(1), M(2)}));
CHECK(remainingFactorGraph); CHECK(remainingFactorGraph);
// GTSAM_PRINT(*remainingFactorGraph); // HybridFactorGraph
EXPECT_LONGS_EQUAL(3, remainingFactorGraph->size()); EXPECT_LONGS_EQUAL(3, remainingFactorGraph->size());
EXPECT(remainingFactorGraph->at(0)->keys() == KeyVector({M(1)})); EXPECT(remainingFactorGraph->at(0)->keys() == KeyVector({M(1)}));
EXPECT(remainingFactorGraph->at(1)->keys() == KeyVector({M(2), M(1)})); EXPECT(remainingFactorGraph->at(1)->keys() == KeyVector({M(2), M(1)}));
@ -721,13 +687,8 @@ TEST(HybridFactorGraph, DefaultDecisionTree) {
moving = boost::make_shared<PlanarMotionModel>(X(0), X(1), odometry, moving = boost::make_shared<PlanarMotionModel>(X(0), X(1), odometry,
noise_model); noise_model);
std::vector<PlanarMotionModel::shared_ptr> motion_models = {still, moving}; std::vector<PlanarMotionModel::shared_ptr> motion_models = {still, moving};
// TODO(Varun) Make a templated constructor for MixtureFactor which does this?
std::vector<NonlinearFactor::shared_ptr> components;
for (auto&& f : motion_models) {
components.push_back(boost::dynamic_pointer_cast<NonlinearFactor>(f));
}
fg.emplace_hybrid<MixtureFactor>( fg.emplace_hybrid<MixtureFactor>(
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components); contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, motion_models);
// Add Range-Bearing measurements to from X0 to L0 and X1 to L1. // Add Range-Bearing measurements to from X0 to L0 and X1 to L1.
// create a noise model for the landmark measurements // create a noise model for the landmark measurements
@ -757,7 +718,7 @@ TEST(HybridFactorGraph, DefaultDecisionTree) {
ordering += X(0); ordering += X(0);
ordering += X(1); ordering += X(1);
HybridGaussianFactorGraph linearized = fg.linearize(initialEstimate); HybridGaussianFactorGraph linearized = *fg.linearize(initialEstimate);
gtsam::HybridBayesNet::shared_ptr hybridBayesNet; gtsam::HybridBayesNet::shared_ptr hybridBayesNet;
gtsam::HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; gtsam::HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;

View File

@ -0,0 +1,586 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridNonlinearISAM.cpp
* @brief Unit tests for nonlinear incremental inference
* @author Varun Agrawal, Fan Jiang, Frank Dellaert
* @date Jan 2021
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/geometry/Pose2.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridNonlinearISAM.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/PriorFactor.h>
#include <gtsam/sam/BearingRangeFactor.h>
#include <numeric>
#include "Switching.h"
// Include for test suite
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::L;
using symbol_shorthand::M;
using symbol_shorthand::W;
using symbol_shorthand::X;
using symbol_shorthand::Y;
using symbol_shorthand::Z;
/* ****************************************************************************/
// Test if we can perform elimination incrementally.
TEST(HybridNonlinearISAM, IncrementalElimination) {
  Switching switching(3);  // 3-state switching fixture (see Switching.h)
  HybridNonlinearISAM isam;
  HybridNonlinearFactorGraph graph1;
  Values initial;

  // Create initial factor graph
  // * * *
  // | | |
  // X1 -*- X2 -*- X3
  //  \*-M1-*/
  graph1.push_back(switching.nonlinearFactorGraph.at(0));  // P(X1)
  graph1.push_back(switching.nonlinearFactorGraph.at(1));  // P(X1, X2 | M1)
  graph1.push_back(switching.nonlinearFactorGraph.at(2));  // P(X2, X3 | M2)
  graph1.push_back(switching.nonlinearFactorGraph.at(5));  // P(M1)

  // Linearization points for the three continuous variables.
  initial.insert<double>(X(1), 1);
  initial.insert<double>(X(2), 2);
  initial.insert<double>(X(3), 3);

  // Run update step
  isam.update(graph1, initial);

  // Check that after update we have 3 hybrid Bayes net nodes:
  // P(X1 | X2, M1) and P(X2, X3 | M1, M2), P(M1, M2)
  HybridGaussianISAM bayesTree = isam.bayesTree();
  EXPECT_LONGS_EQUAL(3, bayesTree.size());
  EXPECT(bayesTree[X(1)]->conditional()->frontals() == KeyVector{X(1)});
  EXPECT(bayesTree[X(1)]->conditional()->parents() == KeyVector({X(2), M(1)}));
  EXPECT(bayesTree[X(2)]->conditional()->frontals() == KeyVector({X(2), X(3)}));
  EXPECT(bayesTree[X(2)]->conditional()->parents() == KeyVector({M(1), M(2)}));

  /********************************************************/
  // New factor graph for incremental update.
  HybridNonlinearFactorGraph graph2;
  initial = Values();

  // NOTE(review): this pushes to graph1, not graph2, so the P(X2) factor is
  // never seen by the second update() call below — likely a typo; confirm
  // whether graph2.push_back was intended.
  graph1.push_back(switching.nonlinearFactorGraph.at(3));  // P(X2)
  graph2.push_back(switching.nonlinearFactorGraph.at(4));  // P(X3)
  graph2.push_back(switching.nonlinearFactorGraph.at(6));  // P(M1, M2)

  // Second incremental update; no new continuous values are needed.
  isam.update(graph2, initial);
  bayesTree = isam.bayesTree();

  // Check that after the second update we have
  // 1 additional hybrid Bayes net node:
  // P(X2, X3 | M1, M2)
  EXPECT_LONGS_EQUAL(3, bayesTree.size());
  EXPECT(bayesTree[X(3)]->conditional()->frontals() == KeyVector({X(2), X(3)}));
  EXPECT(bayesTree[X(3)]->conditional()->parents() == KeyVector({M(1), M(2)}));
}
/* ****************************************************************************/
// Test if we can incrementally do the inference
TEST(HybridNonlinearISAM, IncrementalInference) {
  Switching switching(3);  // 3-state switching fixture (see Switching.h)
  HybridNonlinearISAM isam;
  HybridNonlinearFactorGraph graph1;
  Values initial;

  // Create initial factor graph
  // * * *
  // | | |
  // X1 -*- X2 -*- X3
  //         |        |
  //         *-M1 - * - M2
  graph1.push_back(switching.nonlinearFactorGraph.at(0));  // P(X1)
  graph1.push_back(switching.nonlinearFactorGraph.at(1));  // P(X1, X2 | M1)
  graph1.push_back(switching.nonlinearFactorGraph.at(3));  // P(X2)
  graph1.push_back(switching.nonlinearFactorGraph.at(5));  // P(M1)

  // Linearization points for the first two continuous variables.
  initial.insert<double>(X(1), 1);
  initial.insert<double>(X(2), 2);

  // Run update step
  isam.update(graph1, initial);
  HybridGaussianISAM bayesTree = isam.bayesTree();

  // After the first update, the discrete conditional on M(1) should have
  // no parents.
  auto discreteConditional_m1 =
      bayesTree[M(1)]->conditional()->asDiscreteConditional();
  EXPECT(discreteConditional_m1->keys() == KeyVector({M(1)}));

  /********************************************************/
  // New factor graph for incremental update.
  HybridNonlinearFactorGraph graph2;
  initial = Values();
  initial.insert<double>(X(3), 3);

  graph2.push_back(switching.nonlinearFactorGraph.at(2));  // P(X2, X3 | M2)
  graph2.push_back(switching.nonlinearFactorGraph.at(4));  // P(X3)
  graph2.push_back(switching.nonlinearFactorGraph.at(6));  // P(M1, M2)

  isam.update(graph2, initial);
  bayesTree = isam.bayesTree();

  /********************************************************/
  // Run batch elimination so we can compare results.
  Ordering ordering;
  ordering += X(1);
  ordering += X(2);
  ordering += X(3);

  // Now we calculate the actual factors using full elimination
  HybridBayesTree::shared_ptr expectedHybridBayesTree;
  HybridGaussianFactorGraph::shared_ptr expectedRemainingGraph;
  std::tie(expectedHybridBayesTree, expectedRemainingGraph) =
      switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering);

  // The densities on X(1) should be the same
  auto x1_conditional = dynamic_pointer_cast<GaussianMixture>(
      bayesTree[X(1)]->conditional()->inner());
  auto actual_x1_conditional = dynamic_pointer_cast<GaussianMixture>(
      (*expectedHybridBayesTree)[X(1)]->conditional()->inner());
  EXPECT(assert_equal(*x1_conditional, *actual_x1_conditional));

  // The densities on X(2) should be the same
  auto x2_conditional = dynamic_pointer_cast<GaussianMixture>(
      bayesTree[X(2)]->conditional()->inner());
  auto actual_x2_conditional = dynamic_pointer_cast<GaussianMixture>(
      (*expectedHybridBayesTree)[X(2)]->conditional()->inner());
  EXPECT(assert_equal(*x2_conditional, *actual_x2_conditional));

  // The densities on X(3) should be the same
  auto x3_conditional = dynamic_pointer_cast<GaussianMixture>(
      bayesTree[X(3)]->conditional()->inner());
  // NOTE(review): the expected side fetches the batch conditional at X(2),
  // not X(3) — looks like a copy-paste slip; confirm X(3) was intended.
  auto actual_x3_conditional = dynamic_pointer_cast<GaussianMixture>(
      (*expectedHybridBayesTree)[X(2)]->conditional()->inner());
  EXPECT(assert_equal(*x3_conditional, *actual_x3_conditional));

  // We only perform manual continuous elimination for 0,0.
  // The other discrete probabilities on M(2) are calculated the same way
  Ordering discrete_ordering;
  discrete_ordering += M(1);
  discrete_ordering += M(2);
  HybridBayesTree::shared_ptr discreteBayesTree =
      expectedRemainingGraph->eliminateMultifrontal(discrete_ordering);

  DiscreteValues m00;
  m00[M(1)] = 0, m00[M(2)] = 0;  // assignment m1=0, m2=0
  DiscreteConditional decisionTree =
      *(*discreteBayesTree)[M(2)]->conditional()->asDiscreteConditional();
  double m00_prob = decisionTree(m00);

  auto discreteConditional =
      bayesTree[M(2)]->conditional()->asDiscreteConditional();

  // Test if the probability values are as expected with regression tests.
  DiscreteValues assignment;
  EXPECT(assert_equal(m00_prob, 0.0619233, 1e-5));
  assignment[M(1)] = 0;
  assignment[M(2)] = 0;
  EXPECT(assert_equal(m00_prob, (*discreteConditional)(assignment), 1e-5));
  assignment[M(1)] = 1;
  assignment[M(2)] = 0;
  EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5));
  assignment[M(1)] = 0;
  assignment[M(2)] = 1;
  EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5));
  assignment[M(1)] = 1;
  assignment[M(2)] = 1;
  EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5));

  // Check if the clique conditional generated from incremental elimination
  // matches that of batch elimination.
  auto expectedChordal = expectedRemainingGraph->eliminateMultifrontal();
  auto expectedConditional = dynamic_pointer_cast<DecisionTreeFactor>(
      (*expectedChordal)[M(2)]->conditional()->inner());
  auto actualConditional = dynamic_pointer_cast<DecisionTreeFactor>(
      bayesTree[M(2)]->conditional()->inner());
  EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6));
}
/* ****************************************************************************/
// Test if we can approximately do the inference
TEST(HybridNonlinearISAM, Approx_inference) {
  Switching switching(4);  // 4-state switching fixture (see Switching.h)
  HybridNonlinearISAM incrementalHybrid;
  HybridNonlinearFactorGraph graph1;
  Values initial;

  // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
  for (size_t i = 1; i < 4; i++) {
    graph1.push_back(switching.nonlinearFactorGraph.at(i));
  }

  // Add the Gaussian factors, 1 prior on X(1),
  // 3 measurements on X(2), X(3), X(4)
  graph1.push_back(switching.nonlinearFactorGraph.at(0));
  for (size_t i = 4; i <= 7; i++) {
    graph1.push_back(switching.nonlinearFactorGraph.at(i));
    initial.insert<double>(X(i - 3), i - 3);
  }

  // Create ordering.
  Ordering ordering;
  for (size_t j = 1; j <= 4; j++) {
    ordering += X(j);
  }

  // Now we calculate the actual factors using full elimination
  // (kept unpruned, for comparison against the pruned ISAM result below).
  HybridBayesTree::shared_ptr unprunedHybridBayesTree;
  HybridGaussianFactorGraph::shared_ptr unprunedRemainingGraph;
  std::tie(unprunedHybridBayesTree, unprunedRemainingGraph) =
      switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering);

  // Incremental update, then prune the Bayes tree down to 5 leaves.
  size_t maxNrLeaves = 5;
  incrementalHybrid.update(graph1, initial);
  HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
  bayesTree.prune(maxNrLeaves);

  /*
  unpruned factor is:
    Choice(m3)
    0 Choice(m2)
    0 0 Choice(m1)
    0 0 0 Leaf 0.11267528
    0 0 1 Leaf 0.18576102
    0 1 Choice(m1)
    0 1 0 Leaf 0.18754662
    0 1 1 Leaf 0.30623871
    1 Choice(m2)
    1 0 Choice(m1)
    1 0 0 Leaf 0.18576102
    1 0 1 Leaf 0.30622428
    1 1 Choice(m1)
    1 1 0 Leaf 0.30623871
    1 1 1 Leaf 0.5

  pruned factors is:
    Choice(m3)
    0 Choice(m2)
    0 0 Leaf 0
    0 1 Choice(m1)
    0 1 0 Leaf 0.18754662
    0 1 1 Leaf 0.30623871
    1 Choice(m2)
    1 0 Choice(m1)
    1 0 0 Leaf 0
    1 0 1 Leaf 0.30622428
    1 1 Choice(m1)
    1 1 0 Leaf 0.30623871
    1 1 1 Leaf 0.5
  */
  auto discreteConditional_m1 = *dynamic_pointer_cast<DiscreteConditional>(
      bayesTree[M(1)]->conditional()->inner());
  EXPECT(discreteConditional_m1.keys() == KeyVector({M(1), M(2), M(3)}));

  // Get the number of elements which are greater than 0.
  auto count = [](const double &value, int count) {
    return value > 0 ? count + 1 : count;
  };
  // Check that the number of leaves after pruning is 5.
  EXPECT_LONGS_EQUAL(5, discreteConditional_m1.fold(count, 0));

  // Check that the hybrid nodes of the bayes net match those of the pre-pruning
  // bayes net, at the same positions.
  auto &unprunedLastDensity = *dynamic_pointer_cast<GaussianMixture>(
      unprunedHybridBayesTree->clique(X(4))->conditional()->inner());
  auto &lastDensity = *dynamic_pointer_cast<GaussianMixture>(
      bayesTree[X(4)]->conditional()->inner());

  std::vector<std::pair<DiscreteValues, double>> assignments =
      discreteConditional_m1.enumerate();
  // Loop over all assignments and check the pruned components:
  // pruned-away assignments must be null, surviving ones must match
  // the unpruned density exactly.
  for (auto &&av : assignments) {
    const DiscreteValues &assignment = av.first;
    const double value = av.second;

    if (value == 0.0) {
      EXPECT(lastDensity(assignment) == nullptr);
    } else {
      CHECK(lastDensity(assignment));
      EXPECT(assert_equal(*unprunedLastDensity(assignment),
                          *lastDensity(assignment)));
    }
  }
}
/* ****************************************************************************/
// Test approximate inference with an additional pruning step.
TEST(HybridNonlinearISAM, Incremental_approximate) {
  Switching switching(5);  // 5-state switching fixture (see Switching.h)
  HybridNonlinearISAM incrementalHybrid;
  HybridNonlinearFactorGraph graph1;
  Values initial;

  /***** Run Round 1 *****/
  // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
  for (size_t i = 1; i < 4; i++) {
    graph1.push_back(switching.nonlinearFactorGraph.at(i));
  }

  // Add the Gaussian factors, 1 prior on X(1),
  // 3 measurements on X(2), X(3), X(4)
  graph1.push_back(switching.nonlinearFactorGraph.at(0));
  initial.insert<double>(X(1), 1);
  for (size_t i = 5; i <= 7; i++) {
    graph1.push_back(switching.nonlinearFactorGraph.at(i));
    initial.insert<double>(X(i - 3), i - 3);
  }

  // Run update with pruning
  size_t maxComponents = 5;
  incrementalHybrid.update(graph1, initial);
  HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
  bayesTree.prune(maxComponents);

  // Check if we have a bayes tree with 4 hybrid nodes,
  // each with 2, 4, 8, and 5 (pruned) leaves respectively.
  EXPECT_LONGS_EQUAL(4, bayesTree.size());
  EXPECT_LONGS_EQUAL(
      2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents());
  EXPECT_LONGS_EQUAL(
      3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
  EXPECT_LONGS_EQUAL(
      5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents());
  EXPECT_LONGS_EQUAL(
      5, bayesTree[X(4)]->conditional()->asMixture()->nrComponents());

  /***** Run Round 2 *****/
  // NOTE(review): nonlinear factors are pushed below and the graph is passed
  // to HybridNonlinearISAM::update(); the Gaussian graph type here looks like
  // it should be HybridNonlinearFactorGraph — confirm.
  HybridGaussianFactorGraph graph2;
  graph2.push_back(switching.nonlinearFactorGraph.at(4));  // x4-x5
  graph2.push_back(switching.nonlinearFactorGraph.at(8));  // x5 measurement

  initial = Values();
  initial.insert<double>(X(5), 5);

  // Run update with pruning a second time.
  incrementalHybrid.update(graph2, initial);
  bayesTree = incrementalHybrid.bayesTree();
  bayesTree.prune(maxComponents);

  // Check if we have a bayes tree with pruned hybrid nodes,
  // with 5 (pruned) leaves.
  CHECK_EQUAL(5, bayesTree.size());
  EXPECT_LONGS_EQUAL(
      5, bayesTree[X(4)]->conditional()->asMixture()->nrComponents());
  EXPECT_LONGS_EQUAL(
      5, bayesTree[X(5)]->conditional()->asMixture()->nrComponents());
}
/* ************************************************************************/
// A GTSAM-only test for running inference on a single-legged robot.
// The leg links are represented by the chain X-Y-Z-W, where X is the base and
// W is the foot.
// We use BetweenFactor<Pose2> as constraints between each of the poses.
TEST(HybridNonlinearISAM, NonTrivial) {
  /*************** Run Round 1 ***************/
  HybridNonlinearFactorGraph fg;
  HybridNonlinearISAM inc;

  // Add a prior on pose x1 at the origin.
  // A prior factor consists of a mean and
  // a noise model (covariance matrix)
  Pose2 prior(0.0, 0.0, 0.0);  // prior mean is at origin
  auto priorNoise = noiseModel::Diagonal::Sigmas(
      Vector3(0.3, 0.3, 0.1));  // 30cm std on x,y, 0.1 rad on theta
  fg.emplace_nonlinear<PriorFactor<Pose2>>(X(0), prior, priorNoise);

  // create a noise model for the landmark measurements
  auto poseNoise = noiseModel::Isotropic::Sigma(3, 0.1);

  // We model a robot's single leg as X - Y - Z - W
  // where X is the base link and W is the foot link.
  // Add connecting poses similar to PoseFactors in GTD
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(0), Y(0), Pose2(0, 1.0, 0),
                                             poseNoise);
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(Y(0), Z(0), Pose2(0, 1.0, 0),
                                             poseNoise);
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(Z(0), W(0), Pose2(0, 1.0, 0),
                                             poseNoise);

  // Create initial estimate
  Values initial;
  initial.insert(X(0), Pose2(0.0, 0.0, 0.0));
  initial.insert(Y(0), Pose2(0.0, 1.0, 0.0));
  initial.insert(Z(0), Pose2(0.0, 2.0, 0.0));
  initial.insert(W(0), Pose2(0.0, 3.0, 0.0));

  // Don't run update now since we don't have discrete variables involved.

  using PlanarMotionModel = BetweenFactor<Pose2>;

  /*************** Run Round 2 ***************/
  // Add odometry factor with discrete modes.
  Pose2 odometry(1.0, 0.0, 0.0);
  KeyVector contKeys = {W(0), W(1)};
  auto noise_model = noiseModel::Isotropic::Sigma(3, 1.0);
  // Two motion hypotheses: "still" (zero motion) vs "moving" (odometry).
  auto still = boost::make_shared<PlanarMotionModel>(W(0), W(1), Pose2(0, 0, 0),
                                                     noise_model),
       moving = boost::make_shared<PlanarMotionModel>(W(0), W(1), odometry,
                                                      noise_model);
  std::vector<PlanarMotionModel::shared_ptr> components = {moving, still};
  auto mixtureFactor = boost::make_shared<MixtureFactor>(
      contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components);
  fg.push_back(mixtureFactor);

  // Add equivalent of ImuFactor
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(0), X(1), Pose2(1.0, 0.0, 0),
                                             poseNoise);
  // PoseFactors-like at k=1
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(1), Y(1), Pose2(0, 1, 0),
                                             poseNoise);
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(Y(1), Z(1), Pose2(0, 1, 0),
                                             poseNoise);
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(Z(1), W(1), Pose2(-1, 1, 0),
                                             poseNoise);

  initial.insert(X(1), Pose2(1.0, 0.0, 0.0));
  initial.insert(Y(1), Pose2(1.0, 1.0, 0.0));
  initial.insert(Z(1), Pose2(1.0, 2.0, 0.0));
  // The leg link did not move so we set the expected pose accordingly.
  initial.insert(W(1), Pose2(0.0, 3.0, 0.0));

  // Update without pruning
  // The result is a HybridBayesNet with 1 discrete variable M(1).
  // P(X | measurements) = P(W0|Z0, W1, M1) P(Z0|Y0, W1, M1) P(Y0|X0, W1, M1)
  //                       P(X0 | X1, W1, M1) P(W1|Z1, X1, M1) P(Z1|Y1, X1, M1)
  //                       P(Y1 | X1, M1)P(X1 | M1)P(M1)
  // The MHS tree is a 1 level tree for time indices (1,) with 2 leaves.
  inc.update(fg, initial);

  // Reset containers for the next round of incremental factors.
  fg = HybridNonlinearFactorGraph();
  initial = Values();

  /*************** Run Round 3 ***************/
  // Add odometry factor with discrete modes.
  contKeys = {W(1), W(2)};
  still = boost::make_shared<PlanarMotionModel>(W(1), W(2), Pose2(0, 0, 0),
                                                noise_model);
  moving =
      boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry, noise_model);
  components = {moving, still};
  mixtureFactor = boost::make_shared<MixtureFactor>(
      contKeys, DiscreteKeys{gtsam::DiscreteKey(M(2), 2)}, components);
  fg.push_back(mixtureFactor);

  // Add equivalent of ImuFactor
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(1), X(2), Pose2(1.0, 0.0, 0),
                                             poseNoise);
  // PoseFactors-like at k=1
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(2), Y(2), Pose2(0, 1, 0),
                                             poseNoise);
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(Y(2), Z(2), Pose2(0, 1, 0),
                                             poseNoise);
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(Z(2), W(2), Pose2(-2, 1, 0),
                                             poseNoise);

  initial.insert(X(2), Pose2(2.0, 0.0, 0.0));
  initial.insert(Y(2), Pose2(2.0, 1.0, 0.0));
  initial.insert(Z(2), Pose2(2.0, 2.0, 0.0));
  initial.insert(W(2), Pose2(0.0, 3.0, 0.0));

  // Now we prune!
  // P(X | measurements) = P(W0|Z0, W1, M1) P(Z0|Y0, W1, M1) P(Y0|X0, W1, M1)
  //                       P(X0 | X1, W1, M1) P(W1|W2, Z1, X1, M1, M2)
  //                       P(Z1| W2, Y1, X1, M1, M2) P(Y1 | W2, X1, M1, M2)
  //                       P(X1 | W2, X2, M1, M2) P(W2|Z2, X2, M1, M2)
  //                       P(Z2|Y2, X2, M1, M2) P(Y2 | X2, M1, M2)
  //                       P(X2 | M1, M2) P(M1, M2)
  // The MHS at this point should be a 2 level tree on (1, 2).
  // 1 has 2 choices, and 2 has 4 choices.
  inc.update(fg, initial);
  inc.prune(2);

  fg = HybridNonlinearFactorGraph();
  initial = Values();

  /*************** Run Round 4 ***************/
  // Add odometry factor with discrete modes.
  contKeys = {W(2), W(3)};
  still = boost::make_shared<PlanarMotionModel>(W(2), W(3), Pose2(0, 0, 0),
                                                noise_model);
  moving =
      boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry, noise_model);
  components = {moving, still};
  mixtureFactor = boost::make_shared<MixtureFactor>(
      contKeys, DiscreteKeys{gtsam::DiscreteKey(M(3), 2)}, components);
  fg.push_back(mixtureFactor);

  // Add equivalent of ImuFactor
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(2), X(3), Pose2(1.0, 0.0, 0),
                                             poseNoise);
  // PoseFactors-like at k=3
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(3), Y(3), Pose2(0, 1, 0),
                                             poseNoise);
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(Y(3), Z(3), Pose2(0, 1, 0),
                                             poseNoise);
  fg.emplace_nonlinear<BetweenFactor<Pose2>>(Z(3), W(3), Pose2(-3, 1, 0),
                                             poseNoise);

  initial.insert(X(3), Pose2(3.0, 0.0, 0.0));
  initial.insert(Y(3), Pose2(3.0, 1.0, 0.0));
  initial.insert(Z(3), Pose2(3.0, 2.0, 0.0));
  initial.insert(W(3), Pose2(0.0, 3.0, 0.0));

  // Keep pruning!
  inc.update(fg, initial);
  inc.prune(3);

  fg = HybridNonlinearFactorGraph();
  initial = Values();

  HybridGaussianISAM bayesTree = inc.bayesTree();

  // The final discrete graph should not be empty since we have eliminated
  // all continuous variables.
  auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional();
  EXPECT_LONGS_EQUAL(3, discreteTree->size());

  // Test if the optimal discrete mode assignment is (1, 1, 1).
  DiscreteFactorGraph discreteGraph;
  discreteGraph.push_back(discreteTree);
  DiscreteValues optimal_assignment = discreteGraph.optimize();

  DiscreteValues expected_assignment;
  expected_assignment[M(1)] = 1;
  expected_assignment[M(2)] = 1;
  expected_assignment[M(3)] = 1;

  EXPECT(assert_equal(expected_assignment, optimal_assignment));

  // Test if pruning worked correctly by checking that
  // we only have 3 leaves in the last node.
  auto lastConditional = bayesTree[X(3)]->conditional()->asMixture();
  EXPECT_LONGS_EQUAL(3, lastConditional->nrComponents());
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}

View File

@ -33,7 +33,6 @@ namespace gtsam {
// Forward declarations // Forward declarations
template<class FACTOR> class FactorGraph; template<class FACTOR> class FactorGraph;
template<class BAYESTREE, class GRAPH> class EliminatableClusterTree; template<class BAYESTREE, class GRAPH> class EliminatableClusterTree;
class HybridBayesTreeClique;
/* ************************************************************************* */ /* ************************************************************************* */
/** clique statistics */ /** clique statistics */

View File

@ -33,7 +33,7 @@ struct ConstructorTraversalData {
typedef typename JunctionTree<BAYESTREE, GRAPH>::sharedNode sharedNode; typedef typename JunctionTree<BAYESTREE, GRAPH>::sharedNode sharedNode;
ConstructorTraversalData* const parentData; ConstructorTraversalData* const parentData;
sharedNode myJTNode; sharedNode junctionTreeNode;
FastVector<SymbolicConditional::shared_ptr> childSymbolicConditionals; FastVector<SymbolicConditional::shared_ptr> childSymbolicConditionals;
FastVector<SymbolicFactor::shared_ptr> childSymbolicFactors; FastVector<SymbolicFactor::shared_ptr> childSymbolicFactors;
@ -53,8 +53,9 @@ struct ConstructorTraversalData {
// a traversal data structure with its own JT node, and create a child // a traversal data structure with its own JT node, and create a child
// pointer in its parent. // pointer in its parent.
ConstructorTraversalData myData = ConstructorTraversalData(&parentData); ConstructorTraversalData myData = ConstructorTraversalData(&parentData);
myData.myJTNode = boost::make_shared<Node>(node->key, node->factors); myData.junctionTreeNode =
parentData.myJTNode->addChild(myData.myJTNode); boost::make_shared<Node>(node->key, node->factors);
parentData.junctionTreeNode->addChild(myData.junctionTreeNode);
return myData; return myData;
} }
@ -91,7 +92,7 @@ struct ConstructorTraversalData {
myData.parentData->childSymbolicConditionals.push_back(myConditional); myData.parentData->childSymbolicConditionals.push_back(myConditional);
myData.parentData->childSymbolicFactors.push_back(mySeparatorFactor); myData.parentData->childSymbolicFactors.push_back(mySeparatorFactor);
sharedNode node = myData.myJTNode; sharedNode node = myData.junctionTreeNode;
const FastVector<SymbolicConditional::shared_ptr>& childConditionals = const FastVector<SymbolicConditional::shared_ptr>& childConditionals =
myData.childSymbolicConditionals; myData.childSymbolicConditionals;
node->problemSize_ = (int) (myConditional->size() * symbolicFactors.size()); node->problemSize_ = (int) (myConditional->size() * symbolicFactors.size());
@ -138,14 +139,14 @@ JunctionTree<BAYESTREE, GRAPH>::JunctionTree(
typedef typename EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>::Node ETreeNode; typedef typename EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>::Node ETreeNode;
typedef ConstructorTraversalData<BAYESTREE, GRAPH, ETreeNode> Data; typedef ConstructorTraversalData<BAYESTREE, GRAPH, ETreeNode> Data;
Data rootData(0); Data rootData(0);
rootData.myJTNode = boost::make_shared<typename Base::Node>(); // Make a dummy node to gather // Make a dummy node to gather the junction tree roots
// the junction tree roots rootData.junctionTreeNode = boost::make_shared<typename Base::Node>();
treeTraversal::DepthFirstForest(eliminationTree, rootData, treeTraversal::DepthFirstForest(eliminationTree, rootData,
Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPre,
Data::ConstructorTraversalVisitorPostAlg2); Data::ConstructorTraversalVisitorPostAlg2);
// Assign roots from the dummy node // Assign roots from the dummy node
this->addChildrenAsRoots(rootData.myJTNode); this->addChildrenAsRoots(rootData.junctionTreeNode);
// Transfer remaining factors from elimination tree // Transfer remaining factors from elimination tree
Base::remainingFactors_ = eliminationTree.remainingFactors(); Base::remainingFactors_ = eliminationTree.remainingFactors();

View File

@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) {
EXPECT(equalsBinary(graph)); EXPECT(equalsBinary(graph));
} }
/* ****************************************************************************/
TEST(Serialization, gaussian_bayes_net) {
// Create an arbitrary Bayes Net
GaussianBayesNet gbn;
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3,
(Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4,
(Matrix2() << 11.0, 12.0, 13.0, 14.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(),
2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4,
(Matrix2() << 25.0, 26.0, 27.0, 28.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(),
3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(),
4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished()));
std::string serialized = serialize(gbn);
GaussianBayesNet actual;
deserialize(serialized, actual);
EXPECT(assert_equal(gbn, actual));
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST (Serialization, gaussian_bayes_tree) { TEST (Serialization, gaussian_bayes_tree) {
const Key x1=1, x2=2, x3=3, x4=4; const Key x1=1, x2=2, x3=3, x4=4;

View File

@ -0,0 +1,136 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DsfTrackGenerator.cpp
* @date October 2022
* @author John Lambert
* @brief Identifies connected components in the keypoint matches graph.
*/
#include <gtsam/sfm/DsfTrackGenerator.h>

#include <algorithm>
#include <iomanip>
#include <iostream>
namespace gtsam {
namespace gtsfm {
typedef DSFMap<IndexPair> DSFMapIndexPair;
/// Generate the DSF (union-find forest) to form tracks.
///
/// Each union-find element is an IndexPair (i, k): keypoint k in image i.
/// Two detections are merged whenever a correspondence row links them
/// across an image pair.
static DSFMapIndexPair generateDSF(const MatchIndicesMap& matches) {
  DSFMapIndexPair dsf;

  for (const auto& kv : matches) {
    // Bind by reference to avoid copying the (K,2) correspondence matrix
    // for every image pair.
    const auto& pair_indices = kv.first;
    const auto& corr_indices = kv.second;

    // Image pair is (i1,i2).
    const size_t i1 = pair_indices.first;
    const size_t i2 = pair_indices.second;
    // Eigen::Index matches the signed type returned by rows(), avoiding a
    // signed/unsigned comparison.
    for (Eigen::Index k = 0; k < corr_indices.rows(); k++) {
      // Measurement indices are found in a single matrix row, as (k1,k2).
      const size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1);

      // Unique key for DSF is (i,k), representing keypoint index in an image.
      dsf.merge({i1, k1}, {i2, k2});
    }
  }
  return dsf;
}
/// Generate a single track from a set of index pairs
static SfmTrack2d trackFromIndexPairs(const std::set<IndexPair>& index_pair_set,
                                      const KeypointsVector& keypoints) {
  SfmTrack2d track;
  // Each DSF element encodes (camera index, keypoint index) as (i, j).
  for (const IndexPair& detection : index_pair_set) {
    const size_t cam_idx = detection.i();
    const size_t kpt_idx = detection.j();
    // Add this camera's 2d measurement to the track.
    track.addMeasurement(cam_idx, keypoints[cam_idx].coordinates.row(kpt_idx));
  }
  return track;
}
/// Generate tracks from the DSF.
static std::vector<SfmTrack2d> tracksFromDSF(const DSFMapIndexPair& dsf,
const KeypointsVector& keypoints) {
const std::map<IndexPair, std::set<IndexPair> > key_sets = dsf.sets();
// Return immediately if no sets were found.
if (key_sets.empty()) return {};
// Create a list of tracks.
// Each track will be represented as a list of (camera_idx, measurements).
std::vector<SfmTrack2d> tracks2d;
for (const auto& kv : key_sets) {
// Initialize track from measurements.
SfmTrack2d track2d = trackFromIndexPairs(kv.second, keypoints);
tracks2d.emplace_back(track2d);
}
return tracks2d;
}
/**
* @brief Creates a list of tracks from 2d point correspondences.
*
* Creates a disjoint-set forest (DSF) and 2d tracks from pairwise matches.
* We create a singleton for union-find set elements from camera index of a
* detection and the index of that detection in that camera's keypoint list,
* i.e. (i,k).
*
* @param Map from (i1,i2) image pair indices to (K,2) matrix, for K
* correspondence indices, from each image.
* @param Length-N list of keypoints, for N images/cameras.
*/
std::vector<SfmTrack2d> tracksFromPairwiseMatches(
const MatchIndicesMap& matches, const KeypointsVector& keypoints,
bool verbose) {
// Generate the DSF to form tracks.
if (verbose) std::cout << "[SfmTrack2d] Starting Union-Find..." << std::endl;
DSFMapIndexPair dsf = generateDSF(matches);
if (verbose) std::cout << "[SfmTrack2d] Union-Find Complete" << std::endl;
std::vector<SfmTrack2d> tracks2d = tracksFromDSF(dsf, keypoints);
// Filter out erroneous tracks that had repeated measurements within the
// same image. This is an expected result from an incorrect correspondence
// slipping through.
std::vector<SfmTrack2d> validTracks;
std::copy_if(
tracks2d.begin(), tracks2d.end(), std::back_inserter(validTracks),
[](const SfmTrack2d& track2d) { return track2d.hasUniqueCameras(); });
if (verbose) {
size_t erroneous_track_count = tracks2d.size() - validTracks.size();
double erroneous_percentage = static_cast<float>(erroneous_track_count) /
static_cast<float>(tracks2d.size()) * 100;
std::cout << std::fixed << std::setprecision(2);
std::cout << "DSF Union-Find: " << erroneous_percentage;
std::cout << "% of tracks discarded from multiple obs. in a single image."
<< std::endl;
}
// TODO(johnwlambert): return the Transitivity failure percentage here.
return tracks2d;
}
} // namespace gtsfm
} // namespace gtsam

View File

@ -0,0 +1,78 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DsfTrackGenerator.h
* @date July 2022
* @author John Lambert
* @brief Identifies connected components in the keypoint matches graph.
*/
#pragma once
#include <gtsam/base/DSFMap.h>
#include <gtsam/sfm/SfmTrack.h>
#include <Eigen/Core>
#include <map>
#include <optional>
#include <vector>
namespace gtsam {
namespace gtsfm {
typedef Eigen::MatrixX2i CorrespondenceIndices;  // N x 2 array

// Output of detections in an image.
// Coordinate system convention:
// 1. The x coordinate denotes the horizontal direction (+ve direction towards
// the right).
// 2. The y coordinate denotes the vertical direction (+ve direction downwards).
// 3. Origin is at the top left corner of the image.
struct Keypoints {
  // The (x, y) coordinates of the features, of shape Nx2.
  Eigen::MatrixX2d coordinates;

  // Optional scale of the detections, of shape N.
  // Note: gtsam::Vector is typedef'd for Eigen::VectorXd.
  boost::optional<gtsam::Vector> scales;

  /// Optional confidences/responses for each detection, of shape N.
  boost::optional<gtsam::Vector> responses;

  // Construct from coordinates only; `scales` and `responses` are left
  // unset (boost::none).
  Keypoints(const Eigen::MatrixX2d& coordinates)
      : coordinates(coordinates){};  // boost::none
};
// Length-N list of per-image detections, indexed by camera index.
using KeypointsVector = std::vector<Keypoints>;

// Mapping from each image pair to (N,2) array representing indices of matching
// keypoints.
using MatchIndicesMap = std::map<IndexPair, CorrespondenceIndices>;

/**
 * @brief Creates a list of tracks from 2d point correspondences.
 *
 * Creates a disjoint-set forest (DSF) and 2d tracks from pairwise matches.
 * We create a singleton for union-find set elements from camera index of a
 * detection and the index of that detection in that camera's keypoint list,
 * i.e. (i,k).
 *
 * @param matches Map from (i1,i2) image pair indices to (K,2) matrix, for K
 * correspondence indices, from each image.
 * @param keypoints Length-N list of keypoints, for N images/cameras.
 * @param verbose Whether to print progress/statistics (defaults to false).
 * @return List of generated 2d tracks.
 */
std::vector<SfmTrack2d> tracksFromPairwiseMatches(
    const MatchIndicesMap& matches, const KeypointsVector& keypoints,
    bool verbose = false);
} // namespace gtsfm
} // namespace gtsam

View File

@ -22,6 +22,7 @@
#include <gtsam/geometry/Point2.h> #include <gtsam/geometry/Point2.h>
#include <gtsam/geometry/Point3.h> #include <gtsam/geometry/Point3.h>
#include <Eigen/Core>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -35,28 +36,26 @@ typedef std::pair<size_t, Point2> SfmMeasurement;
typedef std::pair<size_t, size_t> SiftIndex; typedef std::pair<size_t, size_t> SiftIndex;
/** /**
* @brief An SfmTrack stores SfM measurements grouped in a track * @brief Track containing 2D measurements associated with a single 3D point.
* @ingroup sfm * Note: Equivalent to gtsam.SfmTrack, but without the 3d measurement.
* This class holds data temporarily before 3D point is initialized.
*/ */
struct GTSAM_EXPORT SfmTrack { struct GTSAM_EXPORT SfmTrack2d {
Point3 p; ///< 3D position of the point
float r, g, b; ///< RGB color of the 3D point
/// The 2D image projections (id,(u,v)) /// The 2D image projections (id,(u,v))
std::vector<SfmMeasurement> measurements; std::vector<SfmMeasurement> measurements;
/// The feature descriptors /// The feature descriptors (optional)
std::vector<SiftIndex> siftIndices; std::vector<SiftIndex> siftIndices;
/// @name Constructors /// @name Constructors
/// @{ /// @{
explicit SfmTrack(float r = 0, float g = 0, float b = 0) // Default constructor.
: p(0, 0, 0), r(r), g(g), b(b) {} SfmTrack2d() = default;
explicit SfmTrack(const gtsam::Point3& pt, float r = 0, float g = 0, // Constructor from measurements.
float b = 0) explicit SfmTrack2d(const std::vector<SfmMeasurement>& measurements)
: p(pt), r(r), g(g), b(b) {} : measurements(measurements) {}
/// @} /// @}
/// @name Standard Interface /// @name Standard Interface
@ -78,6 +77,70 @@ struct GTSAM_EXPORT SfmTrack {
/// Get the SIFT feature index corresponding to the measurement at `idx` /// Get the SIFT feature index corresponding to the measurement at `idx`
const SiftIndex& siftIndex(size_t idx) const { return siftIndices[idx]; } const SiftIndex& siftIndex(size_t idx) const { return siftIndices[idx]; }
/**
* @brief Check that no two measurements are from the same camera.
* @returns boolean result of the validation.
*/
bool hasUniqueCameras() const {
std::vector<int> track_cam_indices;
for (auto& measurement : measurements) {
track_cam_indices.emplace_back(measurement.first);
}
auto i =
std::adjacent_find(track_cam_indices.begin(), track_cam_indices.end());
bool all_cameras_unique = (i == track_cam_indices.end());
return all_cameras_unique;
}
/// @}
/// @name Vectorized Interface
/// @{
/// @brief Return the measurements as a 2D matrix
Eigen::MatrixX2d measurementMatrix() const {
Eigen::MatrixX2d m(numberMeasurements(), 2);
for (size_t i = 0; i < numberMeasurements(); i++) {
m.row(i) = measurement(i).second;
}
return m;
}
/// @brief Return the camera indices of the measurements
Eigen::VectorXi indexVector() const {
Eigen::VectorXi v(numberMeasurements());
for (size_t i = 0; i < numberMeasurements(); i++) {
v(i) = measurement(i).first;
}
return v;
}
/// @}
};
using SfmTrack2dVector = std::vector<SfmTrack2d>;
/**
* @brief An SfmTrack stores SfM measurements grouped in a track
* @addtogroup sfm
*/
struct GTSAM_EXPORT SfmTrack : SfmTrack2d {
Point3 p; ///< 3D position of the point
float r, g, b; ///< RGB color of the 3D point
/// @name Constructors
/// @{
explicit SfmTrack(float r = 0, float g = 0, float b = 0)
: p(0, 0, 0), r(r), g(g), b(b) {}
explicit SfmTrack(const gtsam::Point3& pt, float r = 0, float g = 0,
float b = 0)
: p(pt), r(r), g(g), b(b) {}
/// @}
/// @name Standard Interface
/// @{
/// Get 3D point /// Get 3D point
const Point3& point3() const { return p; } const Point3& point3() const { return p; }

View File

@ -4,10 +4,23 @@
namespace gtsam { namespace gtsam {
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/nonlinear/Values.h>
#include <gtsam/sfm/SfmTrack.h> #include <gtsam/sfm/SfmTrack.h>
class SfmTrack { class SfmTrack2d {
std::vector<pair<size_t, gtsam::Point2>> measurements;
SfmTrack2d();
SfmTrack2d(const std::vector<gtsam::SfmMeasurement>& measurements);
size_t numberMeasurements() const;
pair<size_t, gtsam::Point2> measurement(size_t idx) const;
pair<size_t, size_t> siftIndex(size_t idx) const;
void addMeasurement(size_t idx, const gtsam::Point2& m);
gtsam::SfmMeasurement measurement(size_t idx) const;
bool hasUniqueCameras() const;
Eigen::MatrixX2d measurementMatrix() const;
Eigen::VectorXi indexVector() const;
};
virtual class SfmTrack : gtsam::SfmTrack2d {
SfmTrack(); SfmTrack();
SfmTrack(const gtsam::Point3& pt); SfmTrack(const gtsam::Point3& pt);
const Point3& point3() const; const Point3& point3() const;
@ -18,13 +31,6 @@ class SfmTrack {
double g; double g;
double b; double b;
std::vector<pair<size_t, gtsam::Point2>> measurements;
size_t numberMeasurements() const;
pair<size_t, gtsam::Point2> measurement(size_t idx) const;
pair<size_t, size_t> siftIndex(size_t idx) const;
void addMeasurement(size_t idx, const gtsam::Point2& m);
// enabling serialization functionality // enabling serialization functionality
void serialize() const; void serialize() const;
@ -32,6 +38,8 @@ class SfmTrack {
bool equals(const gtsam::SfmTrack& expected, double tol) const; bool equals(const gtsam::SfmTrack& expected, double tol) const;
}; };
#include <gtsam/nonlinear/Values.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/sfm/SfmData.h> #include <gtsam/sfm/SfmData.h>
class SfmData { class SfmData {
SfmData(); SfmData();
@ -115,7 +123,7 @@ class BinaryMeasurementsRot3 {
#include <gtsam/sfm/ShonanAveraging.h> #include <gtsam/sfm/ShonanAveraging.h>
// TODO(frank): copy/pasta below until we have integer template paremeters in // TODO(frank): copy/pasta below until we have integer template parameters in
// wrap! // wrap!
class ShonanAveragingParameters2 { class ShonanAveragingParameters2 {
@ -310,4 +318,38 @@ class TranslationRecovery {
const gtsam::BinaryMeasurementsUnit3& relativeTranslations) const; const gtsam::BinaryMeasurementsUnit3& relativeTranslations) const;
}; };
namespace gtsfm {
#include <gtsam/sfm/DsfTrackGenerator.h>
class MatchIndicesMap {
MatchIndicesMap();
MatchIndicesMap(const gtsam::gtsfm::MatchIndicesMap& other);
size_t size() const;
bool empty() const;
void clear();
gtsam::gtsfm::CorrespondenceIndices at(const pair<size_t, size_t>& keypair) const;
};
class Keypoints {
Keypoints(const Eigen::MatrixX2d& coordinates);
Eigen::MatrixX2d coordinates;
};
class KeypointsVector {
KeypointsVector();
KeypointsVector(const gtsam::gtsfm::KeypointsVector& other);
void push_back(const gtsam::gtsfm::Keypoints& keypoints);
size_t size() const;
bool empty() const;
void clear();
gtsam::gtsfm::Keypoints at(const size_t& index) const;
};
gtsam::SfmTrack2dVector tracksFromPairwiseMatches(
const gtsam::gtsfm::MatchIndicesMap& matches_dict,
const gtsam::gtsfm::KeypointsVector& keypoints_list, bool verbose = false);
} // namespace gtsfm
} // namespace gtsam } // namespace gtsam

View File

@ -0,0 +1,53 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file TestSfmTrack.cpp
* @date October 2022
* @author Frank Dellaert
* @brief tests for SfmTrack class
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/sfm/SfmTrack.h>
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST(SfmTrack2d, defaultConstructor) {
  // A default-constructed track holds no measurements and no SIFT indices.
  SfmTrack2d track;
  EXPECT_LONGS_EQUAL(0, track.measurements.size());
  EXPECT_LONGS_EQUAL(0, track.siftIndices.size());
}
/* ************************************************************************* */
TEST(SfmTrack2d, measurementConstructor) {
  // Constructing from two (camera, point) measurements stores both; the
  // SIFT index list stays empty.
  SfmTrack2d track({{0, Point2(1, 2)}, {1, Point2(3, 4)}});
  EXPECT_LONGS_EQUAL(2, track.measurements.size());
  EXPECT_LONGS_EQUAL(0, track.siftIndices.size());
}
/* ************************************************************************* */
TEST(SfmTrack, construction) {
  // Point-and-color constructor: (1,2,3) is the 3D point, (4,5,6) the RGB.
  SfmTrack track(Point3(1, 2, 3), 4, 5, 6);
  EXPECT(assert_equal(Point3(1, 2, 3), track.point3()));
  EXPECT(assert_equal(Point3(4, 5, 6), track.rgb()));
}
/* ************************************************************************* */
// Run all registered CppUnitLite tests and return the process exit status.
int main() {
  TestResult tr;
  return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -51,7 +51,10 @@ set(ignore
gtsam::BinaryMeasurementsUnit3 gtsam::BinaryMeasurementsUnit3
gtsam::BinaryMeasurementsRot3 gtsam::BinaryMeasurementsRot3
gtsam::DiscreteKey gtsam::DiscreteKey
gtsam::KeyPairDoubleMap) gtsam::KeyPairDoubleMap
gtsam::gtsfm::MatchIndicesMap
gtsam::gtsfm::KeypointsVector
gtsam::gtsfm::SfmTrack2dVector)
set(interface_headers set(interface_headers
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i ${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
@ -148,8 +151,12 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON)
gtsam::CameraSetCal3Bundler gtsam::CameraSetCal3Bundler
gtsam::CameraSetCal3Unified gtsam::CameraSetCal3Unified
gtsam::CameraSetCal3Fisheye gtsam::CameraSetCal3Fisheye
gtsam::KeyPairDoubleMap) gtsam::KeyPairDoubleMap
gtsam::gtsfm::MatchIndicesMap
gtsam::gtsfm::KeypointsVector
gtsam::gtsfm::SfmTrack2dVector)
pybind_wrap(${GTSAM_PYTHON_UNSTABLE_TARGET} # target pybind_wrap(${GTSAM_PYTHON_UNSTABLE_TARGET} # target
${PROJECT_SOURCE_DIR}/gtsam_unstable/gtsam_unstable.i # interface_header ${PROJECT_SOURCE_DIR}/gtsam_unstable/gtsam_unstable.i # interface_header
"gtsam_unstable.cpp" # generated_cpp "gtsam_unstable.cpp" # generated_cpp

4
python/gtsam/gtsfm.py Normal file
View File

@ -0,0 +1,4 @@
# This trick is to allow direct import of sub-modules
# without this, we can only do `from gtsam.gtsam.gtsfm import X`
# with this trick, we can do `from gtsam.gtsfm import X`
from .gtsam.gtsfm import *

View File

@ -15,12 +15,13 @@
// #include <pybind11/stl.h> // #include <pybind11/stl.h>
#include <pybind11/stl_bind.h> #include <pybind11/stl_bind.h>
PYBIND11_MAKE_OPAQUE( PYBIND11_MAKE_OPAQUE(std::vector<gtsam::SfmMeasurement>);
std::vector<gtsam::SfmTrack>); PYBIND11_MAKE_OPAQUE(std::vector<gtsam::SfmTrack>);
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::SfmCamera>);
PYBIND11_MAKE_OPAQUE(
std::vector<gtsam::SfmCamera>);
PYBIND11_MAKE_OPAQUE( PYBIND11_MAKE_OPAQUE(
std::vector<gtsam::BinaryMeasurement<gtsam::Unit3>>); std::vector<gtsam::BinaryMeasurement<gtsam::Unit3>>);
PYBIND11_MAKE_OPAQUE( PYBIND11_MAKE_OPAQUE(
std::vector<gtsam::BinaryMeasurement<gtsam::Rot3>>); std::vector<gtsam::BinaryMeasurement<gtsam::Rot3>>);
PYBIND11_MAKE_OPAQUE(
std::vector<gtsam::gtsfm::Keypoints>);
PYBIND11_MAKE_OPAQUE(gtsam::gtsfm::MatchIndicesMap);

View File

@ -18,16 +18,11 @@ py::bind_vector<std::vector<gtsam::BinaryMeasurement<gtsam::Unit3> > >(
py::bind_vector<std::vector<gtsam::BinaryMeasurement<gtsam::Rot3> > >( py::bind_vector<std::vector<gtsam::BinaryMeasurement<gtsam::Rot3> > >(
m_, "BinaryMeasurementsRot3"); m_, "BinaryMeasurementsRot3");
py::bind_map<gtsam::KeyPairDoubleMap>(m_, "KeyPairDoubleMap"); py::bind_map<gtsam::KeyPairDoubleMap>(m_, "KeyPairDoubleMap");
py::bind_vector<std::vector<gtsam::SfmTrack2d>>(m_, "SfmTrack2dVector");
py::bind_vector<std::vector<gtsam::SfmTrack>>(m_, "SfmTracks");
py::bind_vector<std::vector<gtsam::SfmCamera>>(m_, "SfmCameras");
py::bind_vector<std::vector<std::pair<size_t, gtsam::Point2>>>(
m_, "SfmMeasurementVector");
py::bind_vector< py::bind_map<gtsam::gtsfm::MatchIndicesMap>(m_, "MatchIndicesMap");
std::vector<gtsam::SfmTrack> >( py::bind_vector<std::vector<gtsam::gtsfm::Keypoints>>(m_, "KeypointsVector");
m_, "SfmTracks");
py::bind_vector<
std::vector<gtsam::SfmCamera> >(
m_, "SfmCameras");
py::bind_vector<
std::vector<std::pair<size_t, gtsam::Point2>>>(
m_, "SfmMeasurementVector"
);

View File

@ -15,8 +15,7 @@ from __future__ import print_function
import unittest import unittest
from typing import Tuple from typing import Tuple
import gtsam from gtsam import DSFMapIndexPair, IndexPair, IndexPairSetAsArray
from gtsam import IndexPair
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
@ -29,10 +28,10 @@ class TestDSFMap(GtsamTestCase):
def key(index_pair) -> Tuple[int, int]: def key(index_pair) -> Tuple[int, int]:
return index_pair.i(), index_pair.j() return index_pair.i(), index_pair.j()
dsf = gtsam.DSFMapIndexPair() dsf = DSFMapIndexPair()
pair1 = gtsam.IndexPair(1, 18) pair1 = IndexPair(1, 18)
self.assertEqual(key(dsf.find(pair1)), key(pair1)) self.assertEqual(key(dsf.find(pair1)), key(pair1))
pair2 = gtsam.IndexPair(2, 2) pair2 = IndexPair(2, 2)
# testing the merge feature of dsf # testing the merge feature of dsf
dsf.merge(pair1, pair2) dsf.merge(pair1, pair2)
@ -45,7 +44,7 @@ class TestDSFMap(GtsamTestCase):
k'th detected keypoint in image i. For the data below, merging such k'th detected keypoint in image i. For the data below, merging such
measurements into feature tracks across frames should create 2 distinct sets. measurements into feature tracks across frames should create 2 distinct sets.
""" """
dsf = gtsam.DSFMapIndexPair() dsf = DSFMapIndexPair()
dsf.merge(IndexPair(0, 1), IndexPair(1, 2)) dsf.merge(IndexPair(0, 1), IndexPair(1, 2))
dsf.merge(IndexPair(0, 1), IndexPair(3, 4)) dsf.merge(IndexPair(0, 1), IndexPair(3, 4))
dsf.merge(IndexPair(4, 5), IndexPair(6, 8)) dsf.merge(IndexPair(4, 5), IndexPair(6, 8))
@ -56,7 +55,7 @@ class TestDSFMap(GtsamTestCase):
for i in sets: for i in sets:
set_keys = [] set_keys = []
s = sets[i] s = sets[i]
for val in gtsam.IndexPairSetAsArray(s): for val in IndexPairSetAsArray(s):
set_keys.append((val.i(), val.j())) set_keys.append((val.i(), val.j()))
merged_sets.add(tuple(set_keys)) merged_sets.add(tuple(set_keys))

View File

@ -0,0 +1,96 @@
"""Unit tests for track generation using a Disjoint Set Forest data structure.
Authors: John Lambert
"""
import unittest
import gtsam
import numpy as np
from gtsam import IndexPair, KeypointsVector, MatchIndicesMap, Point2, SfmMeasurementVector, SfmTrack2d
from gtsam.gtsfm import Keypoints
from gtsam.utils.test_case import GtsamTestCase
class TestDsfTrackGenerator(GtsamTestCase):
    """Tests for DsfTrackGenerator."""

    def test_track_generation(self) -> None:
        """Ensures that DSF generates three tracks from measurements
        in 3 images (H=200,W=400).

        Uses unittest assertion methods instead of bare ``assert``
        statements so failures are reported with values (and are not
        stripped under ``python -O``).
        """
        kps_i0 = Keypoints(np.array([[10.0, 20], [30, 40]]))
        kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
        kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))

        keypoints_list = KeypointsVector()
        keypoints_list.append(kps_i0)
        keypoints_list.append(kps_i1)
        keypoints_list.append(kps_i2)

        # For each image pair (i1,i2), we provide a (K,2) matrix
        # of corresponding image indices (k1,k2).
        matches_dict = MatchIndicesMap()
        matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
        matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])

        tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
            matches_dict,
            keypoints_list,
            verbose=False,
        )
        self.assertEqual(len(tracks), 3)

        # Verify track 0.
        track0 = tracks[0]
        self.assertEqual(track0.numberMeasurements(), 2)
        np.testing.assert_allclose(track0.measurements[0][1], Point2(10, 20))
        np.testing.assert_allclose(track0.measurements[1][1], Point2(50, 60))
        self.assertEqual(track0.measurements[0][0], 0)
        self.assertEqual(track0.measurements[1][0], 1)
        np.testing.assert_allclose(
            track0.measurementMatrix(),
            [
                [10, 20],
                [50, 60],
            ],
        )
        np.testing.assert_allclose(track0.indexVector(), [0, 1])

        # Verify track 1.
        track1 = tracks[1]
        np.testing.assert_allclose(
            track1.measurementMatrix(),
            [
                [30, 40],
                [70, 80],
                [130, 140],
            ],
        )
        np.testing.assert_allclose(track1.indexVector(), [0, 1, 2])

        # Verify track 2.
        track2 = tracks[2]
        np.testing.assert_allclose(
            track2.measurementMatrix(),
            [
                [90, 100],
                [110, 120],
            ],
        )
        np.testing.assert_allclose(track2.indexVector(), [1, 2])
class TestSfmTrack2d(GtsamTestCase):
    """Tests for SfmTrack2d."""

    def test_sfm_track_2d_constructor(self) -> None:
        """Construct a track from one measurement and verify its contents."""
        measurements = SfmMeasurementVector()
        measurements.append((0, Point2(10, 20)))
        track = SfmTrack2d(measurements=measurements)
        # Retrieving the single measurement must not raise.
        track.measurement(0)
        # BUG FIX: the original line `track.numberMeasurements() == 1` was a
        # bare comparison whose result was discarded — it asserted nothing.
        self.assertEqual(track.numberMeasurements(), 1)
# Run the tests in this module when executed directly.
if __name__ == "__main__":
    unittest.main()