Merge branch 'develop' into fix/doxygen
commit
33e24265bd
|
|
@ -101,8 +101,6 @@ if(GTSAM_BUILD_PYTHON OR GTSAM_INSTALL_MATLAB_TOOLBOX)
|
|||
# Copy matlab.h to the correct folder.
|
||||
configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h
|
||||
${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY)
|
||||
# Add the include directories so that matlab.h can be found
|
||||
include_directories("${PROJECT_BINARY_DIR}" "${GTSAM_EIGEN_INCLUDE_FOR_BUILD}")
|
||||
|
||||
add_subdirectory(wrap)
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake")
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ else()
|
|||
find_dependency(Boost @BOOST_FIND_MINIMUM_VERSION@ COMPONENTS @BOOST_FIND_MINIMUM_COMPONENTS@)
|
||||
endif()
|
||||
|
||||
if(@GTSAM_USE_SYSTEM_EIGEN@)
|
||||
find_dependency(Eigen3 REQUIRED)
|
||||
endif()
|
||||
|
||||
# Load exports
|
||||
include(${OUR_CMAKE_DIR}/@PACKAGE_NAME@-exports.cmake)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,81 +0,0 @@
|
|||
# - Try to find Eigen3 lib
|
||||
#
|
||||
# This module supports requiring a minimum version, e.g. you can do
|
||||
# find_package(Eigen3 3.1.2)
|
||||
# to require version 3.1.2 or newer of Eigen3.
|
||||
#
|
||||
# Once done this will define
|
||||
#
|
||||
# EIGEN3_FOUND - system has eigen lib with correct version
|
||||
# EIGEN3_INCLUDE_DIR - the eigen include directory
|
||||
# EIGEN3_VERSION - eigen version
|
||||
|
||||
# Copyright (c) 2006, 2007 Montel Laurent, <montel@kde.org>
|
||||
# Copyright (c) 2008, 2009 Gael Guennebaud, <g.gael@free.fr>
|
||||
# Copyright (c) 2009 Benoit Jacob <jacob.benoit.1@gmail.com>
|
||||
# Redistribution and use is allowed according to the terms of the 2-clause BSD license.
|
||||
|
||||
if(NOT Eigen3_FIND_VERSION)
|
||||
if(NOT Eigen3_FIND_VERSION_MAJOR)
|
||||
set(Eigen3_FIND_VERSION_MAJOR 2)
|
||||
endif(NOT Eigen3_FIND_VERSION_MAJOR)
|
||||
if(NOT Eigen3_FIND_VERSION_MINOR)
|
||||
set(Eigen3_FIND_VERSION_MINOR 91)
|
||||
endif(NOT Eigen3_FIND_VERSION_MINOR)
|
||||
if(NOT Eigen3_FIND_VERSION_PATCH)
|
||||
set(Eigen3_FIND_VERSION_PATCH 0)
|
||||
endif(NOT Eigen3_FIND_VERSION_PATCH)
|
||||
|
||||
set(Eigen3_FIND_VERSION "${Eigen3_FIND_VERSION_MAJOR}.${Eigen3_FIND_VERSION_MINOR}.${Eigen3_FIND_VERSION_PATCH}")
|
||||
endif(NOT Eigen3_FIND_VERSION)
|
||||
|
||||
macro(_eigen3_check_version)
|
||||
file(READ "${EIGEN3_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h" _eigen3_version_header)
|
||||
|
||||
string(REGEX MATCH "define[ \t]+EIGEN_WORLD_VERSION[ \t]+([0-9]+)" _eigen3_world_version_match "${_eigen3_version_header}")
|
||||
set(EIGEN3_WORLD_VERSION "${CMAKE_MATCH_1}")
|
||||
string(REGEX MATCH "define[ \t]+EIGEN_MAJOR_VERSION[ \t]+([0-9]+)" _eigen3_major_version_match "${_eigen3_version_header}")
|
||||
set(EIGEN3_MAJOR_VERSION "${CMAKE_MATCH_1}")
|
||||
string(REGEX MATCH "define[ \t]+EIGEN_MINOR_VERSION[ \t]+([0-9]+)" _eigen3_minor_version_match "${_eigen3_version_header}")
|
||||
set(EIGEN3_MINOR_VERSION "${CMAKE_MATCH_1}")
|
||||
|
||||
set(EIGEN3_VERSION ${EIGEN3_WORLD_VERSION}.${EIGEN3_MAJOR_VERSION}.${EIGEN3_MINOR_VERSION})
|
||||
if(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
|
||||
set(EIGEN3_VERSION_OK FALSE)
|
||||
else(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
|
||||
set(EIGEN3_VERSION_OK TRUE)
|
||||
endif(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
|
||||
|
||||
if(NOT EIGEN3_VERSION_OK)
|
||||
|
||||
message(STATUS "Eigen3 version ${EIGEN3_VERSION} found in ${EIGEN3_INCLUDE_DIR}, "
|
||||
"but at least version ${Eigen3_FIND_VERSION} is required")
|
||||
endif(NOT EIGEN3_VERSION_OK)
|
||||
endmacro(_eigen3_check_version)
|
||||
|
||||
if (EIGEN3_INCLUDE_DIR)
|
||||
|
||||
# in cache already
|
||||
_eigen3_check_version()
|
||||
set(EIGEN3_FOUND ${EIGEN3_VERSION_OK})
|
||||
|
||||
else (EIGEN3_INCLUDE_DIR)
|
||||
|
||||
find_path(EIGEN3_INCLUDE_DIR NAMES signature_of_eigen3_matrix_library
|
||||
PATHS
|
||||
${CMAKE_INSTALL_PREFIX}/include
|
||||
${KDE4_INCLUDE_DIR}
|
||||
PATH_SUFFIXES eigen3 eigen
|
||||
)
|
||||
|
||||
if(EIGEN3_INCLUDE_DIR)
|
||||
_eigen3_check_version()
|
||||
endif(EIGEN3_INCLUDE_DIR)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(Eigen3 DEFAULT_MSG EIGEN3_INCLUDE_DIR EIGEN3_VERSION_OK)
|
||||
|
||||
mark_as_advanced(EIGEN3_INCLUDE_DIR)
|
||||
|
||||
endif(EIGEN3_INCLUDE_DIR)
|
||||
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
###############################################################################
|
||||
# Option for using system Eigen or GTSAM-bundled Eigen
|
||||
# Default: Use system's Eigen if found automatically:
|
||||
find_package(Eigen3 QUIET)
|
||||
find_package(Eigen3 CONFIG QUIET)
|
||||
set(USE_SYSTEM_EIGEN_INITIAL_VALUE ${Eigen3_FOUND})
|
||||
option(GTSAM_USE_SYSTEM_EIGEN "Find and use system-installed Eigen. If 'off', use the one bundled with GTSAM" ${USE_SYSTEM_EIGEN_INITIAL_VALUE})
|
||||
unset(USE_SYSTEM_EIGEN_INITIAL_VALUE)
|
||||
|
|
@ -14,10 +14,14 @@ endif()
|
|||
|
||||
# Switch for using system Eigen or GTSAM-bundled Eigen
|
||||
if(GTSAM_USE_SYSTEM_EIGEN)
|
||||
find_package(Eigen3 REQUIRED) # need to find again as REQUIRED
|
||||
# Since Eigen 3.3.0 a Eigen3Config.cmake is available so use it.
|
||||
find_package(Eigen3 CONFIG REQUIRED) # need to find again as REQUIRED
|
||||
|
||||
# Use generic Eigen include paths e.g. <Eigen/Core>
|
||||
set(GTSAM_EIGEN_INCLUDE_FOR_INSTALL "${EIGEN3_INCLUDE_DIR}")
|
||||
# The actual include directory (for BUILD cmake target interface):
|
||||
# Note: EIGEN3_INCLUDE_DIR points to some random location on some eigen
|
||||
# versions. So here I use the target itself to get the proper include
|
||||
# directory (it is generated by cmake, thus has the correct path)
|
||||
get_target_property(GTSAM_EIGEN_INCLUDE_FOR_BUILD Eigen3::Eigen INTERFACE_INCLUDE_DIRECTORIES)
|
||||
|
||||
# check if MKL is also enabled - can have one or the other, but not both!
|
||||
# Note: Eigen >= v3.2.5 includes our patches
|
||||
|
|
@ -30,9 +34,6 @@ if(GTSAM_USE_SYSTEM_EIGEN)
|
|||
if(EIGEN_USE_MKL_ALL AND (EIGEN3_VERSION VERSION_EQUAL 3.3.4))
|
||||
message(FATAL_ERROR "MKL does not work with Eigen 3.3.4 because of a bug in Eigen. See http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1527. Disable GTSAM_USE_SYSTEM_EIGEN to use GTSAM's copy of Eigen, disable GTSAM_WITH_EIGEN_MKL, or upgrade/patch your installation of Eigen.")
|
||||
endif()
|
||||
|
||||
# The actual include directory (for BUILD cmake target interface):
|
||||
set(GTSAM_EIGEN_INCLUDE_FOR_BUILD "${EIGEN3_INCLUDE_DIR}")
|
||||
else()
|
||||
# Use bundled Eigen include path.
|
||||
# Clear any variables set by FindEigen3
|
||||
|
|
@ -46,6 +47,19 @@ else()
|
|||
|
||||
# The actual include directory (for BUILD cmake target interface):
|
||||
set(GTSAM_EIGEN_INCLUDE_FOR_BUILD "${GTSAM_SOURCE_DIR}/gtsam/3rdparty/Eigen/")
|
||||
|
||||
add_library(gtsam_eigen3 INTERFACE)
|
||||
|
||||
target_include_directories(gtsam_eigen3 INTERFACE
|
||||
$<BUILD_INTERFACE:${GTSAM_EIGEN_INCLUDE_FOR_BUILD}>
|
||||
$<INSTALL_INTERFACE:${GTSAM_EIGEN_INCLUDE_FOR_INSTALL}>
|
||||
)
|
||||
add_library(Eigen3::Eigen ALIAS gtsam_eigen3)
|
||||
|
||||
install(TARGETS gtsam_eigen3 EXPORT GTSAM-exports PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
list(APPEND GTSAM_EXPORTED_TARGETS gtsam_eigen3)
|
||||
set(GTSAM_EXPORTED_TARGETS "${GTSAM_EXPORTED_TARGETS}" PARENT_SCOPE)
|
||||
endif()
|
||||
|
||||
# Detect Eigen version:
|
||||
|
|
|
|||
|
|
@ -117,12 +117,9 @@ set_target_properties(gtsam PROPERTIES
|
|||
VERSION ${gtsam_version}
|
||||
SOVERSION ${gtsam_soversion})
|
||||
|
||||
# Append Eigen include path, set in top-level CMakeLists.txt to either
|
||||
# system-eigen, or GTSAM eigen path
|
||||
target_include_directories(gtsam PUBLIC
|
||||
$<BUILD_INTERFACE:${GTSAM_EIGEN_INCLUDE_FOR_BUILD}>
|
||||
$<INSTALL_INTERFACE:${GTSAM_EIGEN_INCLUDE_FOR_INSTALL}>
|
||||
)
|
||||
target_link_libraries(gtsam PUBLIC Eigen3::Eigen)
|
||||
|
||||
# MKL include dir:
|
||||
if (GTSAM_USE_EIGEN_MKL)
|
||||
target_include_directories(gtsam PUBLIC ${MKL_INCLUDE_DIR})
|
||||
|
|
|
|||
|
|
@ -221,6 +221,6 @@ void PrintForest(const FOREST& forest, std::string str,
|
|||
PrintForestVisitorPre visitor(keyFormatter);
|
||||
DepthFirstForest(forest, str, visitor);
|
||||
}
|
||||
}
|
||||
} // namespace treeTraversal
|
||||
|
||||
}
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -11,15 +11,17 @@
|
|||
|
||||
/**
|
||||
* @file Assignment.h
|
||||
* @brief An assignment from labels to a discrete value index (size_t)
|
||||
* @brief An assignment from labels to a discrete value index (size_t)
|
||||
* @author Frank Dellaert
|
||||
* @date Feb 5, 2012
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -33,13 +35,30 @@ namespace gtsam {
|
|||
*/
|
||||
template <class L>
|
||||
class Assignment : public std::map<L, size_t> {
|
||||
/**
|
||||
* @brief Default method used by `labelFormatter` or `valueFormatter` when
|
||||
* printing.
|
||||
*
|
||||
* @param x The value passed to format.
|
||||
* @return std::string
|
||||
*/
|
||||
static std::string DefaultFormatter(const L& x) {
|
||||
std::stringstream ss;
|
||||
ss << x;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
public:
|
||||
using std::map<L, size_t>::operator=;
|
||||
|
||||
void print(const std::string& s = "Assignment: ") const {
|
||||
void print(const std::string& s = "Assignment: ",
|
||||
const std::function<std::string(L)>& labelFormatter =
|
||||
&DefaultFormatter) const {
|
||||
std::cout << s << ": ";
|
||||
for (const typename Assignment::value_type& keyValue : *this)
|
||||
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
|
||||
for (const typename Assignment::value_type& keyValue : *this) {
|
||||
std::cout << "(" << labelFormatter(keyValue.first) << ", "
|
||||
<< keyValue.second << ")";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -48,4 +48,25 @@ namespace gtsam {
|
|||
return keys & key2;
|
||||
}
|
||||
|
||||
void DiscreteKeys::print(const std::string& s,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
for (auto&& dkey : *this) {
|
||||
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const {
|
||||
if (this->size() != other.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
if (this->at(i).first != other.at(i).first ||
|
||||
this->at(i).second != other.at(i).second) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
#include <gtsam/global_includes.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
|
||||
#include <boost/serialization/vector.hpp>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
|
@ -70,8 +71,30 @@ namespace gtsam {
|
|||
push_back(key);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Print the keys and cardinalities.
|
||||
void print(const std::string& s = "",
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// Check equality to another DiscreteKeys object.
|
||||
bool equals(const DiscreteKeys& other, double tol = 0) const;
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& boost::serialization::make_nvp(
|
||||
"DiscreteKeys",
|
||||
boost::serialization::base_object<std::vector<DiscreteKey>>(*this));
|
||||
}
|
||||
|
||||
}; // DiscreteKeys
|
||||
|
||||
/// Create a list from two keys
|
||||
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
||||
}
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscreteKeys> : public Testable<DiscreteKeys> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -159,6 +159,10 @@ TEST(DiscreteBayesTree, ThinTree) {
|
|||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||
|
||||
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
|
||||
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
|
||||
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
|
||||
|
||||
// check separator marginal P(S9), should be P(14)
|
||||
clique = (*self.bayesTree)[9];
|
||||
DiscreteFactorGraph separatorMarginal9 =
|
||||
|
|
|
|||
|
|
@ -16,14 +16,29 @@
|
|||
* @author Duy-Nguyen Ta
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DisreteKeys, Serialization) {
|
||||
DiscreteKeys keys;
|
||||
keys& DiscreteKey(0, 2);
|
||||
keys& DiscreteKey(1, 3);
|
||||
keys& DiscreteKey(2, 4);
|
||||
|
||||
EXPECT(equalsObj<DiscreteKeys>(keys));
|
||||
EXPECT(equalsXML<DiscreteKeys>(keys));
|
||||
EXPECT(equalsBinary<DiscreteKeys>(keys));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
|
|
@ -31,4 +46,3 @@ int main() {
|
|||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
||||
|
|
|
|||
|
|
@ -33,8 +33,6 @@ static const Point2 P(0.2, 0.7);
|
|||
static const Rot2 R = Rot2::fromAngle(0.3);
|
||||
static const double s = 4;
|
||||
|
||||
const double degree = M_PI / 180;
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Similarity2, Concepts) {
|
||||
BOOST_CONCEPT_ASSERT((IsGroup<Similarity2>));
|
||||
|
|
|
|||
|
|
@ -66,6 +66,27 @@ class KeySet {
|
|||
void serialize() const;
|
||||
};
|
||||
|
||||
// Actually a vector<Key>, needed for Matlab
|
||||
class KeyVector {
|
||||
KeyVector();
|
||||
KeyVector(const gtsam::KeyVector& other);
|
||||
|
||||
// Note: no print function
|
||||
|
||||
// common STL methods
|
||||
size_t size() const;
|
||||
bool empty() const;
|
||||
void clear();
|
||||
|
||||
// structure specific methods
|
||||
size_t at(size_t i) const;
|
||||
size_t front() const;
|
||||
size_t back() const;
|
||||
void push_back(size_t key) const;
|
||||
|
||||
void serialize() const;
|
||||
};
|
||||
|
||||
// Actually a FastMap<Key,int>
|
||||
class KeyGroupMap {
|
||||
KeyGroupMap();
|
||||
|
|
|
|||
|
|
@ -119,33 +119,90 @@ void GaussianMixture::print(const std::string &s,
|
|||
"", [&](Key k) { return formatter(k); },
|
||||
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
|
||||
RedirectCout rd;
|
||||
if (gf && !gf->empty())
|
||||
if (gf && !gf->empty()) {
|
||||
gf->print("", formatter);
|
||||
else
|
||||
return {"nullptr"};
|
||||
return rd.str();
|
||||
return rd.str();
|
||||
} else {
|
||||
return "nullptr";
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
||||
// Functional which loops over all assignments and create a set of
|
||||
// GaussianConditionals
|
||||
auto pruner = [&decisionTree](
|
||||
/* ************************************************************************* */
|
||||
/// Return the DiscreteKey vector as a set.
|
||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||
std::set<DiscreteKey> s;
|
||||
s.insert(dkeys.begin(), dkeys.end());
|
||||
return s;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
*
|
||||
* @param decisionTree The probability decision tree of only discrete keys.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
*/
|
||||
std::function<GaussianConditional::shared_ptr(
|
||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
||||
// Get the discrete keys as sets for the decision tree
|
||||
// and the gaussian mixture.
|
||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||
|
||||
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
|
||||
const Assignment<Key> &choices,
|
||||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues values(choices);
|
||||
|
||||
if (decisionTree(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
boost::shared_ptr<GaussianConditional> null;
|
||||
return null;
|
||||
// Case where the gaussian mixture has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (gaussianMixtureKeySet == decisionTreeKeySet) {
|
||||
if (decisionTree(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
boost::shared_ptr<GaussianConditional> null;
|
||||
return null;
|
||||
} else {
|
||||
return conditional;
|
||||
}
|
||||
} else {
|
||||
return conditional;
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
||||
gaussianMixtureKeySet.begin(),
|
||||
gaussianMixtureKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
DiscreteValues augmented_values(values);
|
||||
augmented_values.insert(assignment.begin(), assignment.end());
|
||||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this conditional.
|
||||
if (decisionTree(augmented_values) > 0.0) {
|
||||
return conditional;
|
||||
}
|
||||
}
|
||||
// If we are here, it means that all the sub-branches are 0,
|
||||
// so we prune.
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
return pruner;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||
// Functional which loops over all assignments and create a set of
|
||||
// GaussianConditionals
|
||||
auto pruner = prunerFunc(decisionTree);
|
||||
|
||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||
conditionals_.root_ = pruned_conditionals.root_;
|
||||
|
|
|
|||
|
|
@ -70,6 +70,17 @@ class GTSAM_EXPORT GaussianMixture
|
|||
*/
|
||||
Sum asGaussianFactorGraphTree() const;
|
||||
|
||||
/**
|
||||
* @brief Helper function to get the pruner functor.
|
||||
*
|
||||
* @param decisionTree The pruned discrete probability decision tree.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
*/
|
||||
std::function<GaussianConditional::shared_ptr(
|
||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
prunerFunc(const DecisionTreeFactor &decisionTree);
|
||||
|
||||
public:
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
|
|
|||
|
|
@ -57,11 +57,12 @@ void GaussianMixtureFactor::print(const std::string &s,
|
|||
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
|
||||
RedirectCout rd;
|
||||
std::cout << ":\n";
|
||||
if (gf)
|
||||
if (gf && !gf->empty()) {
|
||||
gf->print("", formatter);
|
||||
else
|
||||
return {"nullptr"};
|
||||
return rd.str();
|
||||
return rd.str();
|
||||
} else {
|
||||
return "nullptr";
|
||||
}
|
||||
});
|
||||
std::cout << "}" << std::endl;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,23 +15,40 @@
|
|||
* @date January 2022
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
/// Return the DiscreteKey vector as a set.
|
||||
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||
std::set<DiscreteKey> s;
|
||||
s.insert(dkeys.begin(), dkeys.end());
|
||||
return s;
|
||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||
AlgebraicDecisionTree<Key> decisionTree;
|
||||
|
||||
// The canonical decision tree factor which will get the discrete conditionals
|
||||
// added to it.
|
||||
DecisionTreeFactor dtFactor;
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
// Convert to a DecisionTreeFactor and add it to the main factor.
|
||||
DecisionTreeFactor f(*conditional->asDiscreteConditional());
|
||||
dtFactor = dtFactor * f;
|
||||
}
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(dtFactor);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||
// Get the decision tree of only the discrete keys
|
||||
auto discreteConditionals = this->discreteConditionals();
|
||||
const DecisionTreeFactor::shared_ptr discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
|
||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
|
|
@ -41,61 +58,18 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
|
||||
HybridBayesNet prunedBayesNetFragment;
|
||||
|
||||
// Functional which loops over all assignments and create a set of
|
||||
// GaussianConditionals
|
||||
auto pruner = [&](const Assignment<Key> &choices,
|
||||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues values(choices);
|
||||
|
||||
if ((*discreteFactor)(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
boost::shared_ptr<GaussianConditional> null;
|
||||
return null;
|
||||
} else {
|
||||
return conditional;
|
||||
}
|
||||
};
|
||||
|
||||
// Go through all the conditionals in the
|
||||
// Bayes Net and prune them as per discreteFactor.
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
|
||||
GaussianMixture::shared_ptr gaussianMixture =
|
||||
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());
|
||||
if (conditional->isHybrid()) {
|
||||
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
|
||||
|
||||
if (gaussianMixture) {
|
||||
// We may have mixtures with less discrete keys than discreteFactor so we
|
||||
// skip those since the label assignment does not exist.
|
||||
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
|
||||
auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
|
||||
if (gmKeySet != dfKeySet) {
|
||||
// Add the gaussianMixture which doesn't have to be pruned.
|
||||
prunedBayesNetFragment.push_back(
|
||||
boost::make_shared<HybridConditional>(gaussianMixture));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Run the pruning to get a new, pruned tree
|
||||
GaussianMixture::Conditionals prunedTree =
|
||||
gaussianMixture->conditionals().apply(pruner);
|
||||
|
||||
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
|
||||
// reverse keys to get a natural ordering
|
||||
std::reverse(discreteKeys.begin(), discreteKeys.end());
|
||||
|
||||
// Convert from boost::iterator_range to KeyVector
|
||||
// so we can pass it to constructor.
|
||||
KeyVector frontals(gaussianMixture->frontals().begin(),
|
||||
gaussianMixture->frontals().end()),
|
||||
parents(gaussianMixture->parents().begin(),
|
||||
gaussianMixture->parents().end());
|
||||
|
||||
// Create the new gaussian mixture and add it to the bayes net.
|
||||
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
|
||||
frontals, parents, discreteKeys, prunedTree);
|
||||
// Make a copy of the gaussian mixture and prune it!
|
||||
auto prunedGaussianMixture =
|
||||
boost::make_shared<GaussianMixture>(*gaussianMixture);
|
||||
prunedGaussianMixture->prune(*discreteFactor);
|
||||
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
prunedBayesNetFragment.push_back(
|
||||
|
|
@ -111,14 +85,18 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||
return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner());
|
||||
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
|
||||
return factors_.at(i)->asMixture();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||
return factors_.at(i)->asGaussian();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
||||
return boost::dynamic_pointer_cast<DiscreteConditional>(
|
||||
factors_.at(i)->inner());
|
||||
return factors_.at(i)->asDiscreteConditional();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -126,16 +104,45 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
const DiscreteValues &assignment) const {
|
||||
GaussianBayesNet gbn;
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
GaussianMixture gm = *this->atGaussian(idx);
|
||||
gbn.push_back(gm(assignment));
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixture gm = *this->atMixture(idx);
|
||||
gbn.push_back(gm(assignment));
|
||||
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
// If continuous only, add gaussian conditional.
|
||||
gbn.push_back((this->atGaussian(idx)));
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we simply continue.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return gbn;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesNet::optimize() const {
|
||||
auto dag = HybridLookupDAG::FromBayesNet(*this);
|
||||
return dag.argmax();
|
||||
// Solve for the MPE
|
||||
DiscreteBayesNet discrete_bn;
|
||||
for (auto &conditional : factors_) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_bn.push_back(conditional->asDiscreteConditional());
|
||||
}
|
||||
}
|
||||
|
||||
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
|
||||
|
||||
// Given the MPE, compute the optimal continuous values.
|
||||
GaussianBayesNet gbn = this->choose(mpe);
|
||||
return HybridValues(mpe, gbn.optimize());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
||||
GaussianBayesNet gbn = this->choose(assignment);
|
||||
return gbn.optimize();
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/global_includes.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
|
|
@ -39,12 +40,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
|
||||
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Construct empty bayes net */
|
||||
HybridBayesNet() = default;
|
||||
|
||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
||||
HybridBayesNet prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This &bn, double tol = 1e-9) const {
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/// print graph
|
||||
void print(
|
||||
const std::string &s = "",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Add HybridConditional to Bayes Net
|
||||
using Base::add;
|
||||
|
|
@ -55,8 +75,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
|
||||
}
|
||||
|
||||
using Base::push_back;
|
||||
|
||||
/// Get a specific Gaussian mixture by index `i`.
|
||||
GaussianMixture::shared_ptr atGaussian(size_t i) const;
|
||||
GaussianMixture::shared_ptr atMixture(size_t i) const;
|
||||
|
||||
/// Get a specific Gaussian conditional by index `i`.
|
||||
GaussianConditional::shared_ptr atGaussian(size_t i) const;
|
||||
|
||||
/// Get a specific discrete conditional by index `i`.
|
||||
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
|
||||
|
|
@ -70,10 +95,49 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
||||
|
||||
/// Solve the HybridBayesNet by back-substitution.
|
||||
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
|
||||
/// put this method there?
|
||||
/**
|
||||
* @brief Solve the HybridBayesNet by first computing the MPE of all the
|
||||
* discrete variables and then optimizing the continuous variables based on
|
||||
* the MPE assignment.
|
||||
*
|
||||
* @return HybridValues
|
||||
*/
|
||||
HybridValues optimize() const;
|
||||
|
||||
/**
|
||||
* @brief Given the discrete assignment, return the optimized estimate for the
|
||||
* selected Gaussian BayesNet.
|
||||
*
|
||||
* @param assignment An assignment of discrete values.
|
||||
* @return Values
|
||||
*/
|
||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Get all the discrete conditionals as a decision tree factor.
|
||||
*
|
||||
* @return DecisionTreeFactor::shared_ptr
|
||||
*/
|
||||
DecisionTreeFactor::shared_ptr discreteConditionals() const;
|
||||
|
||||
public:
|
||||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
HybridBayesNet prune(size_t maxNrLeaves) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
/// traits
|
||||
template <>
|
||||
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -18,10 +18,13 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/base/treeTraversal-inst.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/inference/BayesTree-inst.h>
|
||||
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
|
||||
#include <gtsam/linear/GaussianJunctionTree.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
|
@ -35,4 +38,161 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
|
|||
return Base::equals(other, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesTree::optimize() const {
|
||||
DiscreteBayesNet dbn;
|
||||
DiscreteValues mpe;
|
||||
|
||||
auto root = roots_.at(0);
|
||||
// Access the clique and get the underlying hybrid conditional
|
||||
HybridConditional::shared_ptr root_conditional = root->conditional();
|
||||
|
||||
// The root should be discrete only, we compute the MPE
|
||||
if (root_conditional->isDiscrete()) {
|
||||
dbn.push_back(root_conditional->asDiscreteConditional());
|
||||
mpe = DiscreteFactorGraph(dbn).optimize();
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"HybridBayesTree root is not discrete-only. Please check elimination "
|
||||
"ordering or use continuous factor graph.");
|
||||
}
|
||||
|
||||
VectorValues values = optimize(mpe);
|
||||
return HybridValues(mpe, values);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* @brief Helper class for Depth First Forest traversal on the HybridBayesTree.
|
||||
*
|
||||
* When traversing the tree, the pre-order visitor will receive an instance of
|
||||
* this class with the parent clique data.
|
||||
*/
|
||||
struct HybridAssignmentData {
|
||||
const DiscreteValues assignment_;
|
||||
GaussianBayesTree::sharedNode parentClique_;
|
||||
// The gaussian bayes tree that will be recursively created.
|
||||
GaussianBayesTree* gaussianbayesTree_;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Hybrid Assignment Data object.
|
||||
*
|
||||
* @param assignment The MPE assignment for the optimal Gaussian cliques.
|
||||
* @param parentClique The clique from the parent node of the current node.
|
||||
* @param gbt The Gaussian Bayes Tree being generated during tree traversal.
|
||||
*/
|
||||
HybridAssignmentData(const DiscreteValues& assignment,
|
||||
const GaussianBayesTree::sharedNode& parentClique,
|
||||
GaussianBayesTree* gbt)
|
||||
: assignment_(assignment),
|
||||
parentClique_(parentClique),
|
||||
gaussianbayesTree_(gbt) {}
|
||||
|
||||
/**
|
||||
* @brief A function used during tree traversal that operates on each node
|
||||
* before visiting the node's children.
|
||||
*
|
||||
* @param node The current node being visited.
|
||||
* @param parentData The HybridAssignmentData from the parent node.
|
||||
* @return HybridAssignmentData which is passed to the children.
|
||||
*/
|
||||
static HybridAssignmentData AssignmentPreOrderVisitor(
|
||||
const HybridBayesTree::sharedNode& node,
|
||||
HybridAssignmentData& parentData) {
|
||||
// Extract the gaussian conditional from the Hybrid clique
|
||||
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
|
||||
GaussianConditional::shared_ptr conditional;
|
||||
if (hybrid_conditional->isHybrid()) {
|
||||
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
|
||||
} else if (hybrid_conditional->isContinuous()) {
|
||||
conditional = hybrid_conditional->asGaussian();
|
||||
} else {
|
||||
// Discrete only conditional, so we set to empty gaussian conditional
|
||||
conditional = boost::make_shared<GaussianConditional>();
|
||||
}
|
||||
|
||||
// Create the GaussianClique for the current node
|
||||
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
|
||||
// Add the current clique to the GaussianBayesTree.
|
||||
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
|
||||
|
||||
// Create new HybridAssignmentData where the current node is the parent
|
||||
// This will be passed down to the children nodes
|
||||
HybridAssignmentData data(parentData.assignment_, clique,
|
||||
parentData.gaussianbayesTree_);
|
||||
return data;
|
||||
}
|
||||
};
|
||||
|
||||
/* *************************************************************************
|
||||
*/
|
||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||
GaussianBayesTree gbt;
|
||||
HybridAssignmentData rootData(assignment, 0, &gbt);
|
||||
{
|
||||
treeTraversal::no_op visitorPost;
|
||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
TbbOpenMPMixedScope threadLimiter;
|
||||
treeTraversal::DepthFirstForestParallel(
|
||||
*this, rootData, HybridAssignmentData::AssignmentPreOrderVisitor,
|
||||
visitorPost);
|
||||
}
|
||||
|
||||
VectorValues result = gbt.optimize();
|
||||
|
||||
// Return the optimized bayes net result.
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
this->roots_.at(0)->conditional()->inner());
|
||||
|
||||
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDecisionTree.root_;
|
||||
|
||||
/// Helper struct for pruning the hybrid bayes tree.
|
||||
struct HybridPrunerData {
|
||||
/// The discrete decision tree after pruning.
|
||||
DecisionTreeFactor prunedDecisionTree;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
|
||||
const HybridBayesTree::sharedNode& parentClique)
|
||||
: prunedDecisionTree(prunedDecisionTree) {}
|
||||
|
||||
/**
|
||||
* @brief A function used during tree traversal that operates on each node
|
||||
* before visiting the node's children.
|
||||
*
|
||||
* @param node The current node being visited.
|
||||
* @param parentData The data from the parent node.
|
||||
* @return HybridPrunerData which is passed to the children.
|
||||
*/
|
||||
static HybridPrunerData AssignmentPreOrderVisitor(
|
||||
const HybridBayesTree::sharedNode& clique,
|
||||
HybridPrunerData& parentData) {
|
||||
// Get the conditional
|
||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||
|
||||
// If conditional is hybrid, we prune it.
|
||||
if (conditional->isHybrid()) {
|
||||
auto gaussianMixture = conditional->asMixture();
|
||||
|
||||
gaussianMixture->prune(parentData.prunedDecisionTree);
|
||||
}
|
||||
return parentData;
|
||||
}
|
||||
};
|
||||
|
||||
HybridPrunerData rootData(prunedDecisionTree, 0);
|
||||
{
|
||||
treeTraversal::no_op visitorPost;
|
||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
TbbOpenMPMixedScope threadLimiter;
|
||||
treeTraversal::DepthFirstForestParallel(
|
||||
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
||||
visitorPost);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -73,9 +73,46 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
|||
/** Check equality */
|
||||
bool equals(const This& other, double tol = 1e-9) const;
|
||||
|
||||
/**
|
||||
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
|
||||
* set of discrete variables and using it to compute the best continuous
|
||||
* update delta.
|
||||
*
|
||||
* @return HybridValues
|
||||
*/
|
||||
HybridValues optimize() const;
|
||||
|
||||
/**
|
||||
* @brief Recursively optimize the BayesTree to produce a vector solution.
|
||||
*
|
||||
* @param assignment The discrete values assignment to select the Gaussian
|
||||
* mixtures.
|
||||
* @return VectorValues
|
||||
*/
|
||||
VectorValues optimize(const DiscreteValues& assignment) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the underlying Bayes tree.
|
||||
*
|
||||
* @param maxNumberLeaves The max number of leaf nodes to keep.
|
||||
*/
|
||||
void prune(const size_t maxNumberLeaves);
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
/// traits
|
||||
template <>
|
||||
struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};
|
||||
|
||||
/**
|
||||
* @brief Class for Hybrid Bayes tree orphan subtrees.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -34,8 +34,6 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridGaussianFactorGraph;
|
||||
|
||||
/**
|
||||
* Hybrid Conditional Density
|
||||
*
|
||||
|
|
@ -71,7 +69,7 @@ class GTSAM_EXPORT HybridConditional
|
|||
BaseConditional; ///< Typedef to our conditional base class
|
||||
|
||||
protected:
|
||||
// Type-erased pointer to the inner type
|
||||
/// Type-erased pointer to the inner type
|
||||
boost::shared_ptr<Factor> inner_;
|
||||
|
||||
public:
|
||||
|
|
@ -129,8 +127,7 @@ class GTSAM_EXPORT HybridConditional
|
|||
* @param gaussianMixture Gaussian Mixture Conditional used to create the
|
||||
* HybridConditional.
|
||||
*/
|
||||
HybridConditional(
|
||||
boost::shared_ptr<GaussianMixture> gaussianMixture);
|
||||
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianMixture
|
||||
|
|
@ -142,6 +139,17 @@ class GTSAM_EXPORT HybridConditional
|
|||
return boost::static_pointer_cast<GaussianMixture>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianConditional
|
||||
*
|
||||
* @return GaussianConditional::shared_ptr
|
||||
*/
|
||||
GaussianConditional::shared_ptr asGaussian() {
|
||||
if (!isContinuous())
|
||||
throw std::invalid_argument("Not a continuous conditional");
|
||||
return boost::static_pointer_cast<GaussianConditional>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return conditional as a DiscreteConditional
|
||||
*
|
||||
|
|
@ -170,10 +178,19 @@ class GTSAM_EXPORT HybridConditional
|
|||
/// Get the type-erased pointer to the inner type
|
||||
boost::shared_ptr<Factor> inner() { return inner_; }
|
||||
|
||||
}; // DiscreteConditional
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
}
|
||||
|
||||
}; // HybridConditional
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<HybridConditional> : public Testable<DiscreteConditional> {};
|
||||
struct traits<HybridConditional> : public Testable<HybridConditional> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -50,10 +50,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
|
|||
|
||||
/* ************************************************************************ */
|
||||
HybridFactor::HybridFactor(const KeyVector &keys)
|
||||
: Base(keys),
|
||||
isContinuous_(true),
|
||||
nrContinuous_(keys.size()),
|
||||
continuousKeys_(keys) {}
|
||||
: Base(keys), isContinuous_(true), continuousKeys_(keys) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
HybridFactor::HybridFactor(const KeyVector &continuousKeys,
|
||||
|
|
@ -62,7 +59,6 @@ HybridFactor::HybridFactor(const KeyVector &continuousKeys,
|
|||
isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)),
|
||||
isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
|
||||
isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
|
||||
nrContinuous_(continuousKeys.size()),
|
||||
discreteKeys_(discreteKeys),
|
||||
continuousKeys_(continuousKeys) {}
|
||||
|
||||
|
|
@ -103,7 +99,6 @@ void HybridFactor::print(const std::string &s,
|
|||
if (d < discreteKeys_.size() - 1) {
|
||||
std::cout << " ";
|
||||
}
|
||||
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,8 +49,6 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
|||
bool isContinuous_ = false;
|
||||
bool isHybrid_ = false;
|
||||
|
||||
size_t nrContinuous_ = 0;
|
||||
|
||||
protected:
|
||||
// Set of DiscreteKeys for this factor.
|
||||
DiscreteKeys discreteKeys_;
|
||||
|
|
@ -131,6 +129,19 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
|||
const KeyVector &continuousKeys() const { return continuousKeys_; }
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
|
||||
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
|
||||
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
|
||||
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
|
||||
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
|
||||
}
|
||||
};
|
||||
// HybridFactor
|
||||
|
||||
|
|
|
|||
|
|
@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
|
|||
push_hybrid(p);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all the discrete keys in the factor graph.
|
||||
const KeySet discreteKeys() const {
|
||||
KeySet discrete_keys;
|
||||
for (auto& factor : factors_) {
|
||||
for (const DiscreteKey& k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
}
|
||||
}
|
||||
return discrete_keys;
|
||||
}
|
||||
|
||||
/// Get all the continuous keys in the factor graph.
|
||||
const KeySet continuousKeys() const {
|
||||
KeySet keys;
|
||||
for (auto& factor : factors_) {
|
||||
for (const Key& key : factor->continuousKeys()) {
|
||||
keys.insert(key);
|
||||
}
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -96,8 +96,12 @@ GaussianMixtureFactor::Sum sumFrontals(
|
|||
}
|
||||
|
||||
} else if (f->isContinuous()) {
|
||||
deferredFactors.push_back(
|
||||
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
|
||||
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
deferredFactors.push_back(gf->inner());
|
||||
}
|
||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
deferredFactors.push_back(cg->asGaussian());
|
||||
}
|
||||
|
||||
} else if (f->isDiscrete()) {
|
||||
// Don't do anything for discrete-only factors
|
||||
|
|
@ -135,9 +139,9 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
|||
for (auto &fp : factors) {
|
||||
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
|
||||
gfg.push_back(ptr->inner());
|
||||
} else if (auto p =
|
||||
boost::static_pointer_cast<HybridConditional>(fp)->inner()) {
|
||||
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
|
||||
} else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
|
||||
gfg.push_back(
|
||||
boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
|
||||
} else {
|
||||
// It is an orphan wrapped conditional
|
||||
}
|
||||
|
|
@ -153,12 +157,14 @@ std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
|||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
DiscreteFactorGraph dfg;
|
||||
for (auto &fp : factors) {
|
||||
if (auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp)) {
|
||||
dfg.push_back(ptr->inner());
|
||||
} else if (auto p =
|
||||
boost::static_pointer_cast<HybridConditional>(fp)->inner()) {
|
||||
dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p));
|
||||
|
||||
for (auto &factor : factors) {
|
||||
if (auto p = boost::dynamic_pointer_cast<HybridDiscreteFactor>(factor)) {
|
||||
dfg.push_back(p->inner());
|
||||
} else if (auto p = boost::static_pointer_cast<HybridConditional>(factor)) {
|
||||
auto discrete_conditional =
|
||||
boost::static_pointer_cast<DiscreteConditional>(p->inner());
|
||||
dfg.push_back(discrete_conditional);
|
||||
} else {
|
||||
// It is an orphan wrapper
|
||||
}
|
||||
|
|
@ -213,10 +219,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
result = EliminatePreferCholesky(graph, frontalKeys);
|
||||
|
||||
if (keysOfEliminated.empty()) {
|
||||
keysOfEliminated =
|
||||
result.first->keys(); // Initialize the keysOfEliminated to be the
|
||||
// Initialize the keysOfEliminated to be the keys of the
|
||||
// eliminated GaussianConditional
|
||||
keysOfEliminated = result.first->keys();
|
||||
}
|
||||
// keysOfEliminated of the GaussianConditional
|
||||
if (keysOfSeparator.empty()) {
|
||||
keysOfSeparator = result.second->keys();
|
||||
}
|
||||
|
|
@ -244,6 +250,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
return exp(-factor->error(empty_values));
|
||||
};
|
||||
DecisionTree<Key, double> fdt(separatorFactors, factorError);
|
||||
|
||||
auto discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||
|
||||
|
|
@ -401,4 +408,19 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
|||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
||||
KeySet discrete_keys = discreteKeys();
|
||||
for (auto &factor : factors_) {
|
||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
}
|
||||
}
|
||||
|
||||
const VariableIndex index(factors_);
|
||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||
return ordering;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -169,6 +169,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
Base::push_back(sharedFactor);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
* eliminated after the continuous keys.
|
||||
*
|
||||
* @return const Ordering
|
||||
*/
|
||||
const Ordering getHybridOrdering() const;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -14,9 +14,10 @@
|
|||
* @date March 31, 2022
|
||||
* @author Fan Jiang
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <gtsam/base/treeTraversal-inst.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||
|
|
@ -41,6 +42,7 @@ HybridGaussianISAM::HybridGaussianISAM(const HybridBayesTree& bayesTree)
|
|||
void HybridGaussianISAM::updateInternal(
|
||||
const HybridGaussianFactorGraph& newFactors,
|
||||
HybridBayesTree::Cliques* orphans,
|
||||
const boost::optional<size_t>& maxNrLeaves,
|
||||
const boost::optional<Ordering>& ordering,
|
||||
const HybridBayesTree::Eliminate& function) {
|
||||
// Remove the contaminated part of the Bayes tree
|
||||
|
|
@ -57,26 +59,28 @@ void HybridGaussianISAM::updateInternal(
|
|||
factors += newFactors;
|
||||
|
||||
// Add the orphaned subtrees
|
||||
for (const sharedClique& orphan : *orphans)
|
||||
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan);
|
||||
|
||||
KeySet allDiscrete;
|
||||
for (auto& factor : factors) {
|
||||
for (auto& k : factor->discreteKeys()) {
|
||||
allDiscrete.insert(k.first);
|
||||
}
|
||||
for (const sharedClique& orphan : *orphans) {
|
||||
factors += boost::make_shared<BayesTreeOrphanWrapper<Node>>(orphan);
|
||||
}
|
||||
|
||||
// Get all the discrete keys from the factors
|
||||
KeySet allDiscrete = factors.discreteKeys();
|
||||
|
||||
// Create KeyVector with continuous keys followed by discrete keys.
|
||||
KeyVector newKeysDiscreteLast;
|
||||
// Insert continuous keys first.
|
||||
for (auto& k : newFactorKeys) {
|
||||
if (!allDiscrete.exists(k)) {
|
||||
newKeysDiscreteLast.push_back(k);
|
||||
}
|
||||
}
|
||||
// Insert discrete keys at the end
|
||||
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
||||
std::back_inserter(newKeysDiscreteLast));
|
||||
|
||||
// Get an ordering where the new keys are eliminated last
|
||||
const VariableIndex index(factors);
|
||||
|
||||
Ordering elimination_ordering;
|
||||
if (ordering) {
|
||||
elimination_ordering = *ordering;
|
||||
|
|
@ -91,6 +95,10 @@ void HybridGaussianISAM::updateInternal(
|
|||
HybridBayesTree::shared_ptr bayesTree =
|
||||
factors.eliminateMultifrontal(elimination_ordering, function, index);
|
||||
|
||||
if (maxNrLeaves) {
|
||||
bayesTree->prune(*maxNrLeaves);
|
||||
}
|
||||
|
||||
// Re-add into Bayes tree data structures
|
||||
this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(),
|
||||
bayesTree->roots().end());
|
||||
|
|
@ -99,61 +107,11 @@ void HybridGaussianISAM::updateInternal(
|
|||
|
||||
/* ************************************************************************* */
|
||||
void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
|
||||
const boost::optional<size_t>& maxNrLeaves,
|
||||
const boost::optional<Ordering>& ordering,
|
||||
const HybridBayesTree::Eliminate& function) {
|
||||
Cliques orphans;
|
||||
this->updateInternal(newFactors, &orphans, ordering, function);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* @brief Check if `b` is a subset of `a`.
|
||||
* Non-const since they need to be sorted.
|
||||
*
|
||||
* @param a KeyVector
|
||||
* @param b KeyVector
|
||||
* @return True if the keys of b is a subset of a, else false.
|
||||
*/
|
||||
bool IsSubset(KeyVector a, KeyVector b) {
|
||||
std::sort(a.begin(), a.end());
|
||||
std::sort(b.begin(), b.end());
|
||||
return std::includes(a.begin(), a.end(), b.begin(), b.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
|
||||
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
this->clique(root)->conditional()->inner());
|
||||
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDiscreteFactor.root_;
|
||||
|
||||
std::vector<gtsam::Key> prunedKeys;
|
||||
for (auto&& clique : nodes()) {
|
||||
// The cliques can be repeated for each frontal so we record it in
|
||||
// prunedKeys and check if we have already pruned a particular clique.
|
||||
if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) !=
|
||||
prunedKeys.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add all the keys of the current clique to be pruned to prunedKeys
|
||||
for (auto&& key : clique.second->conditional()->frontals()) {
|
||||
prunedKeys.push_back(key);
|
||||
}
|
||||
|
||||
// Convert parents() to a KeyVector for comparison
|
||||
KeyVector parents;
|
||||
for (auto&& parent : clique.second->conditional()->parents()) {
|
||||
parents.push_back(parent);
|
||||
}
|
||||
|
||||
if (IsSubset(parents, decisionTree->keys())) {
|
||||
auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
clique.second->conditional()->inner());
|
||||
|
||||
gaussianMixture->prune(prunedDiscreteFactor);
|
||||
}
|
||||
}
|
||||
this->updateInternal(newFactors, &orphans, maxNrLeaves, ordering, function);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
|
|||
void updateInternal(
|
||||
const HybridGaussianFactorGraph& newFactors,
|
||||
HybridBayesTree::Cliques* orphans,
|
||||
const boost::optional<size_t>& maxNrLeaves = boost::none,
|
||||
const boost::optional<Ordering>& ordering = boost::none,
|
||||
const HybridBayesTree::Eliminate& function =
|
||||
HybridBayesTree::EliminationTraitsType::DefaultEliminate);
|
||||
|
|
@ -62,20 +63,15 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
|
|||
* @brief Perform update step with new factors.
|
||||
*
|
||||
* @param newFactors Factor graph of new factors to add and eliminate.
|
||||
* @param maxNrLeaves The maximum number of leaves to keep after pruning.
|
||||
* @param ordering Custom elimination ordering.
|
||||
* @param function Elimination function.
|
||||
*/
|
||||
void update(const HybridGaussianFactorGraph& newFactors,
|
||||
const boost::optional<size_t>& maxNrLeaves = boost::none,
|
||||
const boost::optional<Ordering>& ordering = boost::none,
|
||||
const HybridBayesTree::Eliminate& function =
|
||||
HybridBayesTree::EliminationTraitsType::DefaultEliminate);
|
||||
|
||||
/**
|
||||
* @brief
|
||||
*
|
||||
* @param root The root key in the discrete conditional decision tree.
|
||||
* @param maxNumberLeaves
|
||||
*/
|
||||
void prune(const Key& root, const size_t maxNumberLeaves);
|
||||
};
|
||||
|
||||
/// traits
|
||||
|
|
|
|||
|
|
@ -31,9 +31,7 @@ template class EliminatableClusterTree<HybridBayesTree,
|
|||
template class JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>;
|
||||
|
||||
struct HybridConstructorTraversalData {
|
||||
typedef
|
||||
typename JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>::Node
|
||||
Node;
|
||||
typedef HybridJunctionTree::Node Node;
|
||||
typedef
|
||||
typename JunctionTree<HybridBayesTree,
|
||||
HybridGaussianFactorGraph>::sharedNode sharedNode;
|
||||
|
|
@ -62,6 +60,7 @@ struct HybridConstructorTraversalData {
|
|||
data.junctionTreeNode = boost::make_shared<Node>(node->key, node->factors);
|
||||
parentData.junctionTreeNode->addChild(data.junctionTreeNode);
|
||||
|
||||
// Add all the discrete keys in the hybrid factors to the current data
|
||||
for (HybridFactor::shared_ptr& f : node->factors) {
|
||||
for (auto& k : f->discreteKeys()) {
|
||||
data.discreteKeys.insert(k.first);
|
||||
|
|
@ -72,8 +71,8 @@ struct HybridConstructorTraversalData {
|
|||
}
|
||||
|
||||
// Post-order visitor function
|
||||
static void ConstructorTraversalVisitorPostAlg2(
|
||||
const boost::shared_ptr<HybridEliminationTree::Node>& ETreeNode,
|
||||
static void ConstructorTraversalVisitorPost(
|
||||
const boost::shared_ptr<HybridEliminationTree::Node>& node,
|
||||
const HybridConstructorTraversalData& data) {
|
||||
// In this post-order visitor, we combine the symbolic elimination results
|
||||
// from the elimination tree children and symbolically eliminate the current
|
||||
|
|
@ -86,15 +85,15 @@ struct HybridConstructorTraversalData {
|
|||
|
||||
// Do symbolic elimination for this node
|
||||
SymbolicFactors symbolicFactors;
|
||||
symbolicFactors.reserve(ETreeNode->factors.size() +
|
||||
symbolicFactors.reserve(node->factors.size() +
|
||||
data.childSymbolicFactors.size());
|
||||
// Add ETree node factors
|
||||
symbolicFactors += ETreeNode->factors;
|
||||
symbolicFactors += node->factors;
|
||||
// Add symbolic factors passed up from children
|
||||
symbolicFactors += data.childSymbolicFactors;
|
||||
|
||||
Ordering keyAsOrdering;
|
||||
keyAsOrdering.push_back(ETreeNode->key);
|
||||
keyAsOrdering.push_back(node->key);
|
||||
SymbolicConditional::shared_ptr conditional;
|
||||
SymbolicFactor::shared_ptr separatorFactor;
|
||||
boost::tie(conditional, separatorFactor) =
|
||||
|
|
@ -105,19 +104,19 @@ struct HybridConstructorTraversalData {
|
|||
data.parentData->childSymbolicFactors.push_back(separatorFactor);
|
||||
data.parentData->discreteKeys.merge(data.discreteKeys);
|
||||
|
||||
sharedNode node = data.junctionTreeNode;
|
||||
sharedNode jt_node = data.junctionTreeNode;
|
||||
const FastVector<SymbolicConditional::shared_ptr>& childConditionals =
|
||||
data.childSymbolicConditionals;
|
||||
node->problemSize_ = (int)(conditional->size() * symbolicFactors.size());
|
||||
jt_node->problemSize_ = (int)(conditional->size() * symbolicFactors.size());
|
||||
|
||||
// Merge our children if they are in our clique - if our conditional has
|
||||
// exactly one fewer parent than our child's conditional.
|
||||
const size_t nrParents = conditional->nrParents();
|
||||
const size_t nrChildren = node->nrChildren();
|
||||
const size_t nrChildren = jt_node->nrChildren();
|
||||
assert(childConditionals.size() == nrChildren);
|
||||
|
||||
// decide which children to merge, as index into children
|
||||
std::vector<size_t> nrChildrenFrontals = node->nrFrontalsOfChildren();
|
||||
std::vector<size_t> nrChildrenFrontals = jt_node->nrFrontalsOfChildren();
|
||||
std::vector<bool> merge(nrChildren, false);
|
||||
size_t nrFrontals = 1;
|
||||
for (size_t i = 0; i < nrChildren; i++) {
|
||||
|
|
@ -137,7 +136,7 @@ struct HybridConstructorTraversalData {
|
|||
}
|
||||
|
||||
// now really merge
|
||||
node->mergeChildren(merge);
|
||||
jt_node->mergeChildren(merge);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -161,7 +160,7 @@ HybridJunctionTree::HybridJunctionTree(
|
|||
// the junction tree roots
|
||||
treeTraversal::DepthFirstForest(eliminationTree, rootData,
|
||||
Data::ConstructorTraversalVisitorPre,
|
||||
Data::ConstructorTraversalVisitorPostAlg2);
|
||||
Data::ConstructorTraversalVisitorPost);
|
||||
|
||||
// Assign roots from the dummy node
|
||||
this->addChildrenAsRoots(rootData.junctionTreeNode);
|
||||
|
|
|
|||
|
|
@ -1,76 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteLookupDAG.cpp
|
||||
* @date Aug, 2022
|
||||
* @author Shangjie Xue
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
using std::pair;
|
||||
using std::vector;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************** */
|
||||
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
|
||||
// For discrete conditional, uses argmaxInPlace() method in
|
||||
// DiscreteLookupTable.
|
||||
if (isDiscrete()) {
|
||||
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(
|
||||
&(values->discrete));
|
||||
} else if (isContinuous()) {
|
||||
// For Gaussian conditional, uses solve() method in GaussianConditional.
|
||||
values->continuous.insert(
|
||||
boost::static_pointer_cast<GaussianConditional>(inner_)->solve(
|
||||
values->continuous));
|
||||
} else if (isHybrid()) {
|
||||
// For hybrid conditional, since children should not contain discrete
|
||||
// variable, we can condition on the discrete variable in the parents and
|
||||
// solve the resulting GaussianConditional.
|
||||
auto conditional =
|
||||
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(
|
||||
values->discrete);
|
||||
values->continuous.insert(conditional->solve(values->continuous));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
|
||||
HybridLookupDAG dag;
|
||||
for (auto&& conditional : bayesNet) {
|
||||
HybridLookupTable hlt(*conditional);
|
||||
dag.push_back(hlt);
|
||||
}
|
||||
return dag;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
HybridValues HybridLookupDAG::argmax(HybridValues result) const {
|
||||
// Argmax each node in turn in topological sort order (parents first).
|
||||
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||
lookupTable->argmaxInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
@ -1,119 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file HybridLookupDAG.h
|
||||
* @date Aug, 2022
|
||||
* @author Shangjie Xue
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* @brief HybridLookupTable table for max-product
|
||||
*
|
||||
* Similar to DiscreteLookupTable, inherits from hybrid conditional for
|
||||
* convenience. Is used in the max-product algorithm.
|
||||
*/
|
||||
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
|
||||
public:
|
||||
using Base = HybridConditional;
|
||||
using This = HybridLookupTable;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Hybrid Lookup Table object form a HybridConditional.
|
||||
*
|
||||
* @param conditional input hybrid conditional
|
||||
*/
|
||||
HybridLookupTable(HybridConditional& conditional) : Base(conditional){};
|
||||
|
||||
/**
|
||||
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||
* @param (in/out) parentsValues Known assignments for the parents.
|
||||
*/
|
||||
void argmaxInPlace(HybridValues* parentsValues) const;
|
||||
};
|
||||
|
||||
/** A DAG made from hybrid lookup tables, as defined above. Similar to
|
||||
* DiscreteLookupDAG */
|
||||
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
|
||||
public:
|
||||
using Base = BayesNet<HybridLookupTable>;
|
||||
using This = HybridLookupDAG;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Construct empty DAG.
|
||||
HybridLookupDAG() {}
|
||||
|
||||
/// Create from BayesNet with LookupTables
|
||||
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet);
|
||||
|
||||
/// Destructor
|
||||
virtual ~HybridLookupDAG() {}
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Add a DiscreteLookupTable */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief argmax by back-substitution, optionally given certain variables.
|
||||
*
|
||||
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||
* conditional will be optimized first *and* that the
|
||||
* DAG does not contain any conditionals for the given variables. If the DAG
|
||||
* resulted from eliminating a factor graph, this is true for the elimination
|
||||
* ordering.
|
||||
*
|
||||
* @return given assignment extended w. optimal assignment for all variables.
|
||||
*/
|
||||
HybridValues argmax(HybridValues given = HybridValues()) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
@ -27,8 +27,7 @@ void HybridNonlinearFactorGraph::add(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridNonlinearFactorGraph::add(
|
||||
boost::shared_ptr<DiscreteFactor> factor) {
|
||||
void HybridNonlinearFactorGraph::add(boost::shared_ptr<DiscreteFactor> factor) {
|
||||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
||||
}
|
||||
|
||||
|
|
@ -49,12 +48,12 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize(
|
||||
HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
||||
const Values& continuousValues) const {
|
||||
// create an empty linear FG
|
||||
HybridGaussianFactorGraph linearFG;
|
||||
auto linearFG = boost::make_shared<HybridGaussianFactorGraph>();
|
||||
|
||||
linearFG.reserve(size());
|
||||
linearFG->reserve(size());
|
||||
|
||||
// linearize all hybrid factors
|
||||
for (auto&& factor : factors_) {
|
||||
|
|
@ -66,9 +65,9 @@ HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize(
|
|||
if (factor->isHybrid()) {
|
||||
// Check if it is a nonlinear mixture factor
|
||||
if (auto nlmf = boost::dynamic_pointer_cast<MixtureFactor>(factor)) {
|
||||
linearFG.push_back(nlmf->linearize(continuousValues));
|
||||
linearFG->push_back(nlmf->linearize(continuousValues));
|
||||
} else {
|
||||
linearFG.push_back(factor);
|
||||
linearFG->push_back(factor);
|
||||
}
|
||||
|
||||
// Now check if the factor is a continuous only factor.
|
||||
|
|
@ -80,18 +79,18 @@ HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize(
|
|||
boost::dynamic_pointer_cast<NonlinearFactor>(nlhf->inner())) {
|
||||
auto hgf = boost::make_shared<HybridGaussianFactor>(
|
||||
nlf->linearize(continuousValues));
|
||||
linearFG.push_back(hgf);
|
||||
linearFG->push_back(hgf);
|
||||
} else {
|
||||
linearFG.push_back(factor);
|
||||
linearFG->push_back(factor);
|
||||
}
|
||||
// Finally if nothing else, we are discrete-only which doesn't need
|
||||
// lineariztion.
|
||||
} else {
|
||||
linearFG.push_back(factor);
|
||||
linearFG->push_back(factor);
|
||||
}
|
||||
|
||||
} else {
|
||||
linearFG.push_back(GaussianFactor::shared_ptr());
|
||||
linearFG->push_back(GaussianFactor::shared_ptr());
|
||||
}
|
||||
}
|
||||
return linearFG;
|
||||
|
|
|
|||
|
|
@ -42,6 +42,16 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
|||
using IsNonlinear = typename std::enable_if<
|
||||
std::is_base_of<NonlinearFactor, FACTOR>::value>::type;
|
||||
|
||||
/// Check if T has a value_type derived from FactorType.
|
||||
template <typename T>
|
||||
using HasDerivedValueType = typename std::enable_if<
|
||||
std::is_base_of<HybridFactor, typename T::value_type>::value>::type;
|
||||
|
||||
/// Check if T has a pointer type derived from FactorType.
|
||||
template <typename T>
|
||||
using HasDerivedElementType = typename std::enable_if<std::is_base_of<
|
||||
HybridFactor, typename T::value_type::element_type>::value>::type;
|
||||
|
||||
public:
|
||||
using Base = HybridFactorGraph;
|
||||
using This = HybridNonlinearFactorGraph; ///< this class
|
||||
|
|
@ -109,6 +119,21 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Push back many factors as shared_ptr's in a container (factors are not
|
||||
* copied)
|
||||
*/
|
||||
template <typename CONTAINER>
|
||||
HasDerivedElementType<CONTAINER> push_back(const CONTAINER& container) {
|
||||
Base::push_back(container.begin(), container.end());
|
||||
}
|
||||
|
||||
/// Push back non-pointer objects in a container (factors are copied).
|
||||
template <typename CONTAINER>
|
||||
HasDerivedValueType<CONTAINER> push_back(const CONTAINER& container) {
|
||||
Base::push_back(container.begin(), container.end());
|
||||
}
|
||||
|
||||
/// Add a nonlinear factor as a shared ptr.
|
||||
void add(boost::shared_ptr<NonlinearFactor> factor);
|
||||
|
||||
|
|
@ -127,7 +152,8 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
|||
* @param continuousValues: Dictionary of continuous values.
|
||||
* @return HybridGaussianFactorGraph::shared_ptr
|
||||
*/
|
||||
HybridGaussianFactorGraph linearize(const Values& continuousValues) const;
|
||||
HybridGaussianFactorGraph::shared_ptr linearize(
|
||||
const Values& continuousValues) const;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,114 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file HybridNonlinearISAM.cpp
|
||||
* @date Sep 12, 2022
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearISAM.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridNonlinearISAM::saveGraph(const string& s,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
isam_.saveGraph(s, keyFormatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors,
|
||||
const Values& initialValues,
|
||||
const boost::optional<size_t>& maxNrLeaves,
|
||||
const boost::optional<Ordering>& ordering) {
|
||||
if (newFactors.size() > 0) {
|
||||
// Reorder and relinearize every reorderInterval updates
|
||||
if (reorderInterval_ > 0 && ++reorderCounter_ >= reorderInterval_) {
|
||||
reorder_relinearize();
|
||||
reorderCounter_ = 0;
|
||||
}
|
||||
|
||||
factors_.push_back(newFactors);
|
||||
|
||||
// Linearize new factors and insert them
|
||||
// TODO: optimize for whole config?
|
||||
linPoint_.insert(initialValues);
|
||||
|
||||
boost::shared_ptr<HybridGaussianFactorGraph> linearizedNewFactors =
|
||||
newFactors.linearize(linPoint_);
|
||||
|
||||
// Update ISAM
|
||||
isam_.update(*linearizedNewFactors, maxNrLeaves, ordering,
|
||||
eliminationFunction_);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridNonlinearISAM::reorder_relinearize() {
|
||||
if (factors_.size() > 0) {
|
||||
// Obtain the new linearization point
|
||||
const Values newLinPoint = estimate();
|
||||
|
||||
isam_.clear();
|
||||
|
||||
// Just recreate the whole BayesTree
|
||||
// TODO: allow for constrained ordering here
|
||||
// TODO: decouple relinearization and reordering to avoid
|
||||
isam_.update(*factors_.linearize(newLinPoint), boost::none, boost::none,
|
||||
eliminationFunction_);
|
||||
|
||||
// Update linearization point
|
||||
linPoint_ = newLinPoint;
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Values HybridNonlinearISAM::estimate() {
|
||||
Values result;
|
||||
if (isam_.size() > 0) {
|
||||
HybridValues values = isam_.optimize();
|
||||
assignment_ = values.discrete();
|
||||
return linPoint_.retract(values.continuous());
|
||||
} else {
|
||||
return linPoint_;
|
||||
}
|
||||
}
|
||||
|
||||
// /* *************************************************************************
|
||||
// */ Matrix HybridNonlinearISAM::marginalCovariance(Key key) const {
|
||||
// return isam_.marginalCovariance(key);
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridNonlinearISAM::print(const string& s,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
cout << s << "ReorderInterval: " << reorderInterval_
|
||||
<< " Current Count: " << reorderCounter_ << endl;
|
||||
isam_.print("HybridGaussianISAM:\n", keyFormatter);
|
||||
linPoint_.print("Linearization Point:\n", keyFormatter);
|
||||
factors_.print("Nonlinear Graph:\n", keyFormatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridNonlinearISAM::printStats() const {
|
||||
isam_.getCliqueData().getStats().print();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file HybridNonlinearISAM.h
|
||||
* @date Sep 12, 2022
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
/**
|
||||
* Wrapper class to manage ISAM in a nonlinear context
|
||||
*/
|
||||
class GTSAM_EXPORT HybridNonlinearISAM {
|
||||
protected:
|
||||
/** The internal iSAM object */
|
||||
gtsam::HybridGaussianISAM isam_;
|
||||
|
||||
/** The current linearization point */
|
||||
Values linPoint_;
|
||||
|
||||
/// The discrete assignment
|
||||
DiscreteValues assignment_;
|
||||
|
||||
/** The original factors, used when relinearizing */
|
||||
HybridNonlinearFactorGraph factors_;
|
||||
|
||||
/** The reordering interval and counter */
|
||||
int reorderInterval_;
|
||||
int reorderCounter_;
|
||||
|
||||
/** The elimination function */
|
||||
HybridGaussianFactorGraph::Eliminate eliminationFunction_;
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* Periodically reorder and relinearize
|
||||
* @param reorderInterval is the number of updates between reorderings,
|
||||
* 0 never reorders (and is dangerous for memory consumption)
|
||||
* 1 (default) reorders every time, in worse case is batch every update
|
||||
* typical values are 50 or 100
|
||||
*/
|
||||
HybridNonlinearISAM(
|
||||
int reorderInterval = 1,
|
||||
const HybridGaussianFactorGraph::Eliminate& eliminationFunction =
|
||||
HybridGaussianFactorGraph::EliminationTraitsType::DefaultEliminate)
|
||||
: reorderInterval_(reorderInterval),
|
||||
reorderCounter_(0),
|
||||
eliminationFunction_(eliminationFunction) {}
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Return the current solution estimate */
|
||||
Values estimate();
|
||||
|
||||
// /** find the marginal covariance for a single variable */
|
||||
// Matrix marginalCovariance(Key key) const;
|
||||
|
||||
// access
|
||||
|
||||
/** access the underlying bayes tree */
|
||||
const HybridGaussianISAM& bayesTree() const { return isam_; }
|
||||
|
||||
/**
|
||||
* @brief Prune the underlying Bayes tree.
|
||||
*
|
||||
* @param maxNumberLeaves The max number of leaf nodes to keep.
|
||||
*/
|
||||
void prune(const size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); }
|
||||
|
||||
/** Return the current linearization point */
|
||||
const Values& getLinearizationPoint() const { return linPoint_; }
|
||||
|
||||
/** Return the current discrete assignment */
|
||||
const DiscreteValues& getAssignment() const { return assignment_; }
|
||||
|
||||
/** get underlying nonlinear graph */
|
||||
const HybridNonlinearFactorGraph& getFactorsUnsafe() const {
|
||||
return factors_;
|
||||
}
|
||||
|
||||
/** get counters */
|
||||
int reorderInterval() const { return reorderInterval_; } ///< TODO: comment
|
||||
int reorderCounter() const { return reorderCounter_; } ///< TODO: comment
|
||||
|
||||
/** prints out all contents of the system */
|
||||
void print(const std::string& s = "",
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/** prints out clique statistics */
|
||||
void printStats() const;
|
||||
|
||||
/** saves the Tree to a text file in GraphViz format */
|
||||
void saveGraph(const std::string& s,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/** Add new factors along with their initial linearization points */
|
||||
void update(const HybridNonlinearFactorGraph& newFactors,
|
||||
const Values& initialValues,
|
||||
const boost::optional<size_t>& maxNrLeaves = boost::none,
|
||||
const boost::optional<Ordering>& ordering = boost::none);
|
||||
|
||||
/** Relinearization and reordering of variables */
|
||||
void reorder_relinearize();
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
@ -31,60 +31,78 @@
|
|||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* HybridValues represents a collection of DiscreteValues and VectorValues. It
|
||||
* is typically used to store the variables of a HybridGaussianFactorGraph.
|
||||
* HybridValues represents a collection of DiscreteValues and VectorValues.
|
||||
* It is typically used to store the variables of a HybridGaussianFactorGraph.
|
||||
* Optimizing a HybridGaussianBayesNet returns this class.
|
||||
*/
|
||||
class GTSAM_EXPORT HybridValues {
|
||||
public:
|
||||
private:
|
||||
// DiscreteValue stored the discrete components of the HybridValues.
|
||||
DiscreteValues discrete;
|
||||
DiscreteValues discrete_;
|
||||
|
||||
// VectorValue stored the continuous components of the HybridValues.
|
||||
VectorValues continuous;
|
||||
VectorValues continuous_;
|
||||
|
||||
// Default constructor creates an empty HybridValues.
|
||||
HybridValues() : discrete(), continuous(){};
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
// Construct from DiscreteValues and VectorValues.
|
||||
/// Default constructor creates an empty HybridValues.
|
||||
HybridValues() = default;
|
||||
|
||||
/// Construct from DiscreteValues and VectorValues.
|
||||
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
|
||||
: discrete(dv), continuous(cv){};
|
||||
: discrete_(dv), continuous_(cv){};
|
||||
|
||||
// print required by Testable for unit testing
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// print required by Testable for unit testing
|
||||
void print(const std::string& s = "HybridValues",
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||
std::cout << s << ": \n";
|
||||
discrete.print(" Discrete", keyFormatter); // print discrete components
|
||||
continuous.print(" Continuous",
|
||||
keyFormatter); // print continuous components
|
||||
discrete_.print(" Discrete", keyFormatter); // print discrete components
|
||||
continuous_.print(" Continuous",
|
||||
keyFormatter); // print continuous components
|
||||
};
|
||||
|
||||
// equals required by Testable for unit testing
|
||||
/// equals required by Testable for unit testing
|
||||
bool equals(const HybridValues& other, double tol = 1e-9) const {
|
||||
return discrete.equals(other.discrete, tol) &&
|
||||
continuous.equals(other.continuous, tol);
|
||||
return discrete_.equals(other.discrete_, tol) &&
|
||||
continuous_.equals(other.continuous_, tol);
|
||||
}
|
||||
|
||||
// Check whether a variable with key \c j exists in DiscreteValue.
|
||||
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); };
|
||||
/// @}
|
||||
/// @name Interface
|
||||
/// @{
|
||||
|
||||
// Check whether a variable with key \c j exists in VectorValue.
|
||||
bool existsVector(Key j) { return continuous.exists(j); };
|
||||
/// Return the discrete MPE assignment
|
||||
DiscreteValues discrete() const { return discrete_; }
|
||||
|
||||
// Check whether a variable with key \c j exists.
|
||||
/// Return the delta update for the continuous vectors
|
||||
VectorValues continuous() const { return continuous_; }
|
||||
|
||||
/// Check whether a variable with key \c j exists in DiscreteValue.
|
||||
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
|
||||
|
||||
/// Check whether a variable with key \c j exists in VectorValue.
|
||||
bool existsVector(Key j) { return continuous_.exists(j); };
|
||||
|
||||
/// Check whether a variable with key \c j exists.
|
||||
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
|
||||
|
||||
/** Insert a discrete \c value with key \c j. Replaces the existing value if
|
||||
* the key \c j is already used.
|
||||
* @param value The vector to be inserted.
|
||||
* @param j The index with which the value will be associated. */
|
||||
void insert(Key j, int value) { discrete[j] = value; };
|
||||
void insert(Key j, int value) { discrete_[j] = value; };
|
||||
|
||||
/** Insert a vector \c value with key \c j. Throws an invalid_argument
|
||||
* exception if the key \c j is already used.
|
||||
* @param value The vector to be inserted.
|
||||
* @param j The index with which the value will be associated. */
|
||||
void insert(Key j, const Vector& value) { continuous.insert(j, value); }
|
||||
void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
|
||||
|
||||
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
|
||||
|
||||
|
|
@ -92,13 +110,13 @@ class GTSAM_EXPORT HybridValues {
|
|||
* Read/write access to the discrete value with key \c j, throws
|
||||
* std::out_of_range if \c j does not exist.
|
||||
*/
|
||||
size_t& atDiscrete(Key j) { return discrete.at(j); };
|
||||
size_t& atDiscrete(Key j) { return discrete_.at(j); };
|
||||
|
||||
/**
|
||||
* Read/write access to the vector value with key \c j, throws
|
||||
* std::out_of_range if \c j does not exist.
|
||||
*/
|
||||
Vector& at(Key j) { return continuous.at(j); };
|
||||
Vector& at(Key j) { return continuous_.at(j); };
|
||||
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
@ -112,8 +130,8 @@ class GTSAM_EXPORT HybridValues {
|
|||
std::string html(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||
std::stringstream ss;
|
||||
ss << this->discrete.html(keyFormatter);
|
||||
ss << this->continuous.html(keyFormatter);
|
||||
ss << this->discrete_.html(keyFormatter);
|
||||
ss << this->continuous_.html(keyFormatter);
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -100,11 +100,23 @@ class MixtureFactor : public HybridFactor {
|
|||
bool normalized = false)
|
||||
: Base(keys, discreteKeys), normalized_(normalized) {
|
||||
std::vector<NonlinearFactor::shared_ptr> nonlinear_factors;
|
||||
KeySet continuous_keys_set(keys.begin(), keys.end());
|
||||
KeySet factor_keys_set;
|
||||
for (auto&& f : factors) {
|
||||
// Insert all factor continuous keys in the continuous keys set.
|
||||
std::copy(f->keys().begin(), f->keys().end(),
|
||||
std::inserter(factor_keys_set, factor_keys_set.end()));
|
||||
|
||||
nonlinear_factors.push_back(
|
||||
boost::dynamic_pointer_cast<NonlinearFactor>(f));
|
||||
}
|
||||
factors_ = Factors(discreteKeys, nonlinear_factors);
|
||||
|
||||
if (continuous_keys_set != factor_keys_set) {
|
||||
throw std::runtime_error(
|
||||
"The specified continuous keys and the keys in the factors don't "
|
||||
"match!");
|
||||
}
|
||||
}
|
||||
|
||||
~MixtureFactor() = default;
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ namespace gtsam {
|
|||
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
class HybridValues {
|
||||
gtsam::DiscreteValues discrete;
|
||||
gtsam::VectorValues continuous;
|
||||
gtsam::DiscreteValues discrete() const;
|
||||
gtsam::VectorValues continuous() const;
|
||||
HybridValues();
|
||||
HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv);
|
||||
void print(string s = "HybridValues",
|
||||
|
|
@ -99,6 +99,8 @@ class HybridBayesTree {
|
|||
bool empty() const;
|
||||
const HybridBayesTreeClique* operator[](size_t j) const;
|
||||
|
||||
gtsam::HybridValues optimize() const;
|
||||
|
||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -115,7 +115,6 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
|
|||
/* ***************************************************************************
|
||||
*/
|
||||
using MotionModel = BetweenFactor<double>;
|
||||
// using MotionMixture = MixtureFactor<MotionModel>;
|
||||
|
||||
// Test fixture with switching network.
|
||||
struct Switching {
|
||||
|
|
@ -125,12 +124,15 @@ struct Switching {
|
|||
HybridGaussianFactorGraph linearizedFactorGraph;
|
||||
Values linearizationPoint;
|
||||
|
||||
/// Create with given number of time steps.
|
||||
/**
|
||||
* @brief Create with given number of time steps.
|
||||
*
|
||||
* @param K The total number of timesteps.
|
||||
* @param between_sigma The stddev between poses.
|
||||
* @param prior_sigma The stddev on priors (also used for measurements).
|
||||
*/
|
||||
Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1)
|
||||
: K(K) {
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
||||
// Create DiscreteKeys for binary K modes, modes[0] will not be used.
|
||||
for (size_t k = 0; k <= K; k++) {
|
||||
modes.emplace_back(M(k), 2);
|
||||
|
|
@ -145,7 +147,7 @@ struct Switching {
|
|||
// Add "motion models".
|
||||
for (size_t k = 1; k < K; k++) {
|
||||
KeyVector keys = {X(k), X(k + 1)};
|
||||
auto motion_models = motionModels(k);
|
||||
auto motion_models = motionModels(k, between_sigma);
|
||||
std::vector<NonlinearFactor::shared_ptr> components;
|
||||
for (auto &&f : motion_models) {
|
||||
components.push_back(boost::dynamic_pointer_cast<NonlinearFactor>(f));
|
||||
|
|
@ -155,7 +157,7 @@ struct Switching {
|
|||
}
|
||||
|
||||
// Add measurement factors
|
||||
auto measurement_noise = noiseModel::Isotropic::Sigma(1, 0.1);
|
||||
auto measurement_noise = noiseModel::Isotropic::Sigma(1, prior_sigma);
|
||||
for (size_t k = 2; k <= K; k++) {
|
||||
nonlinearFactorGraph.emplace_nonlinear<PriorFactor<double>>(
|
||||
X(k), 1.0 * (k - 1), measurement_noise);
|
||||
|
|
@ -169,15 +171,14 @@ struct Switching {
|
|||
linearizationPoint.insert<double>(X(k), static_cast<double>(k));
|
||||
}
|
||||
|
||||
linearizedFactorGraph = nonlinearFactorGraph.linearize(linearizationPoint);
|
||||
// The ground truth is robot moving forward
|
||||
// and one less than the linearization point
|
||||
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
|
||||
}
|
||||
|
||||
// Create motion models for a given time step
|
||||
static std::vector<MotionModel::shared_ptr> motionModels(size_t k,
|
||||
double sigma = 1.0) {
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
||||
auto noise_model = noiseModel::Isotropic::Sigma(1, sigma);
|
||||
auto still =
|
||||
boost::make_shared<MotionModel>(X(k), X(k + 1), 0.0, noise_model),
|
||||
|
|
|
|||
|
|
@ -18,7 +18,10 @@
|
|||
* @date December 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
|
||||
#include "Switching.h"
|
||||
|
||||
|
|
@ -27,6 +30,8 @@
|
|||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
|
@ -47,6 +52,20 @@ TEST(HybridBayesNet, Creation) {
|
|||
EXPECT(df.equals(expected));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test adding a bayes net to another one.
|
||||
TEST(HybridBayesNet, Add) {
|
||||
HybridBayesNet bayesNet;
|
||||
|
||||
bayesNet.add(Asia, "99/1");
|
||||
|
||||
DiscreteConditional expected(Asia, "99/1");
|
||||
|
||||
HybridBayesNet other;
|
||||
other.push_back(bayesNet);
|
||||
EXPECT(bayesNet.equals(other));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test choosing an assignment of conditionals
|
||||
TEST(HybridBayesNet, Choose) {
|
||||
|
|
@ -72,19 +91,128 @@ TEST(HybridBayesNet, Choose) {
|
|||
EXPECT_LONGS_EQUAL(4, gbn.size());
|
||||
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(0)))(assignment),
|
||||
hybridBayesNet->atMixture(0)))(assignment),
|
||||
*gbn.at(0)));
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(1)))(assignment),
|
||||
hybridBayesNet->atMixture(1)))(assignment),
|
||||
*gbn.at(1)));
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(2)))(assignment),
|
||||
hybridBayesNet->atMixture(2)))(assignment),
|
||||
*gbn.at(2)));
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(3)))(assignment),
|
||||
hybridBayesNet->atMixture(3)))(assignment),
|
||||
*gbn.at(3)));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net optimize
|
||||
TEST(HybridBayesNet, OptimizeAssignment) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering ordering;
|
||||
for (auto&& kvp : s.linearizationPoint) {
|
||||
ordering += kvp.key;
|
||||
}
|
||||
|
||||
HybridBayesNet::shared_ptr hybridBayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
|
||||
std::tie(hybridBayesNet, remainingFactorGraph) =
|
||||
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||
|
||||
DiscreteValues assignment;
|
||||
assignment[M(1)] = 1;
|
||||
assignment[M(2)] = 1;
|
||||
assignment[M(3)] = 1;
|
||||
|
||||
VectorValues delta = hybridBayesNet->optimize(assignment);
|
||||
|
||||
// The linearization point has the same value as the key index,
|
||||
// e.g. X(1) = 1, X(2) = 2,
|
||||
// but the factors specify X(k) = k-1, so delta should be -1.
|
||||
VectorValues expected_delta;
|
||||
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
|
||||
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
|
||||
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
|
||||
expected_delta.insert(make_pair(X(4), -Vector1::Ones()));
|
||||
|
||||
EXPECT(assert_equal(expected_delta, delta));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net optimize
|
||||
TEST(HybridBayesNet, Optimize) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
DiscreteValues expectedAssignment;
|
||||
expectedAssignment[M(1)] = 1;
|
||||
expectedAssignment[M(2)] = 0;
|
||||
expectedAssignment[M(3)] = 1;
|
||||
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
||||
|
||||
VectorValues expectedValues;
|
||||
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
|
||||
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
|
||||
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
|
||||
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
|
||||
|
||||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net multifrontal optimize
|
||||
TEST(HybridBayesNet, OptimizeMultifrontal) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesTree::shared_ptr hybridBayesTree =
|
||||
s.linearizedFactorGraph.eliminateMultifrontal(hybridOrdering);
|
||||
HybridValues delta = hybridBayesTree->optimize();
|
||||
|
||||
VectorValues expectedValues;
|
||||
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
|
||||
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
|
||||
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
|
||||
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
|
||||
|
||||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net pruning
|
||||
TEST(HybridBayesNet, Prune) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||
HybridValues pruned_delta = prunedBayesNet.optimize();
|
||||
|
||||
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
|
||||
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridBayesNet, Serialization) {
|
||||
Switching s(4);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||
|
||||
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,166 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testHybridBayesTree.cpp
|
||||
* @brief Unit tests for HybridBayesTree
|
||||
* @author Varun Agrawal
|
||||
* @date August 2022
|
||||
*/
|
||||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||
|
||||
#include "Switching.h"
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test for optimizing a HybridBayesTree with a given assignment.
|
||||
TEST(HybridBayesTree, OptimizeAssignment) {
|
||||
Switching s(4);
|
||||
|
||||
HybridGaussianISAM isam;
|
||||
HybridGaussianFactorGraph graph1;
|
||||
|
||||
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
|
||||
for (size_t i = 1; i < 4; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the Gaussian factors, 1 prior on X(1),
|
||||
// 3 measurements on X(2), X(3), X(4)
|
||||
graph1.push_back(s.linearizedFactorGraph.at(0));
|
||||
for (size_t i = 4; i <= 7; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the discrete factors
|
||||
for (size_t i = 7; i <= 9; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
isam.update(graph1);
|
||||
|
||||
DiscreteValues assignment;
|
||||
assignment[M(1)] = 1;
|
||||
assignment[M(2)] = 1;
|
||||
assignment[M(3)] = 1;
|
||||
|
||||
VectorValues delta = isam.optimize(assignment);
|
||||
|
||||
// The linearization point has the same value as the key index,
|
||||
// e.g. X(1) = 1, X(2) = 2,
|
||||
// but the factors specify X(k) = k-1, so delta should be -1.
|
||||
VectorValues expected_delta;
|
||||
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
|
||||
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
|
||||
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
|
||||
expected_delta.insert(make_pair(X(4), -Vector1::Ones()));
|
||||
|
||||
EXPECT(assert_equal(expected_delta, delta));
|
||||
|
||||
// Create ordering.
|
||||
Ordering ordering;
|
||||
for (size_t k = 1; k <= s.K; k++) ordering += X(k);
|
||||
|
||||
HybridBayesNet::shared_ptr hybridBayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
|
||||
std::tie(hybridBayesNet, remainingFactorGraph) =
|
||||
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||
|
||||
GaussianBayesNet gbn = hybridBayesNet->choose(assignment);
|
||||
VectorValues expected = gbn.optimize();
|
||||
|
||||
EXPECT(assert_equal(expected, delta));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test for optimizing a HybridBayesTree.
|
||||
TEST(HybridBayesTree, Optimize) {
|
||||
Switching s(4);
|
||||
|
||||
HybridGaussianISAM isam;
|
||||
HybridGaussianFactorGraph graph1;
|
||||
|
||||
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
|
||||
for (size_t i = 1; i < 4; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the Gaussian factors, 1 prior on X(1),
|
||||
// 3 measurements on X(2), X(3), X(4)
|
||||
graph1.push_back(s.linearizedFactorGraph.at(0));
|
||||
for (size_t i = 4; i <= 6; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the discrete factors
|
||||
for (size_t i = 7; i <= 9; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
isam.update(graph1);
|
||||
|
||||
HybridValues delta = isam.optimize();
|
||||
|
||||
// Create ordering.
|
||||
Ordering ordering;
|
||||
for (size_t k = 1; k <= s.K; k++) ordering += X(k);
|
||||
|
||||
HybridBayesNet::shared_ptr hybridBayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
|
||||
std::tie(hybridBayesNet, remainingFactorGraph) =
|
||||
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||
|
||||
DiscreteFactorGraph dfg;
|
||||
for (auto&& f : *remainingFactorGraph) {
|
||||
auto factor = dynamic_pointer_cast<HybridDiscreteFactor>(f);
|
||||
dfg.push_back(
|
||||
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
|
||||
}
|
||||
|
||||
DiscreteValues expectedMPE = dfg.optimize();
|
||||
VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE);
|
||||
|
||||
EXPECT(assert_equal(expectedMPE, delta.discrete()));
|
||||
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesTree serialization.
|
||||
TEST(HybridBayesTree, Serialization) {
|
||||
Switching s(4);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesTree hbt =
|
||||
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
||||
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
EXPECT(equalsObj<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsXML<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsBinary<HybridBayesTree>(hbt));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -184,8 +184,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
|
|||
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
||||
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
||||
|
||||
HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal(
|
||||
Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)}));
|
||||
HybridBayesTree::shared_ptr result =
|
||||
hfg.eliminateMultifrontal(hfg.getHybridOrdering());
|
||||
|
||||
// The bayes tree should have 3 cliques
|
||||
EXPECT_LONGS_EQUAL(3, result->size());
|
||||
|
|
@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
|
|||
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
|
||||
|
||||
// Get a constrained ordering keeping c1 last
|
||||
auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(1)});
|
||||
auto ordering_full = hfg.getHybridOrdering();
|
||||
|
||||
// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
|
||||
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
|
||||
|
|
@ -484,8 +484,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
|
|||
}
|
||||
HybridBayesNet::shared_ptr hbn;
|
||||
HybridGaussianFactorGraph::shared_ptr remaining;
|
||||
std::tie(hbn, remaining) =
|
||||
hfg->eliminatePartialSequential(ordering_partial);
|
||||
std::tie(hbn, remaining) = hfg->eliminatePartialSequential(ordering_partial);
|
||||
|
||||
EXPECT_LONGS_EQUAL(14, hbn->size());
|
||||
EXPECT_LONGS_EQUAL(11, remaining->size());
|
||||
|
|
@ -501,6 +500,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(HybridGaussianFactorGraph, optimize) {
|
||||
HybridGaussianFactorGraph hfg;
|
||||
|
||||
|
|
@ -522,6 +522,46 @@ TEST(HybridGaussianFactorGraph, optimize) {
|
|||
|
||||
EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test adding of gaussian conditional and re-elimination.
|
||||
TEST(HybridGaussianFactorGraph, Conditionals) {
|
||||
Switching switching(4);
|
||||
HybridGaussianFactorGraph hfg;
|
||||
|
||||
hfg.push_back(switching.linearizedFactorGraph.at(0)); // P(X1)
|
||||
Ordering ordering;
|
||||
ordering.push_back(X(1));
|
||||
HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering);
|
||||
|
||||
hfg.push_back(switching.linearizedFactorGraph.at(1)); // P(X1, X2 | M1)
|
||||
hfg.push_back(*bayes_net);
|
||||
hfg.push_back(switching.linearizedFactorGraph.at(2)); // P(X2, X3 | M2)
|
||||
hfg.push_back(switching.linearizedFactorGraph.at(5)); // P(M1)
|
||||
ordering.push_back(X(2));
|
||||
ordering.push_back(X(3));
|
||||
ordering.push_back(M(1));
|
||||
ordering.push_back(M(2));
|
||||
|
||||
bayes_net = hfg.eliminateSequential(ordering);
|
||||
|
||||
HybridValues result = bayes_net->optimize();
|
||||
|
||||
Values expected_continuous;
|
||||
expected_continuous.insert<double>(X(1), 0);
|
||||
expected_continuous.insert<double>(X(2), 1);
|
||||
expected_continuous.insert<double>(X(3), 2);
|
||||
expected_continuous.insert<double>(X(4), 4);
|
||||
Values result_continuous =
|
||||
switching.linearizationPoint.retract(result.continuous());
|
||||
EXPECT(assert_equal(expected_continuous, result_continuous));
|
||||
|
||||
DiscreteValues expected_discrete;
|
||||
expected_discrete[M(1)] = 1;
|
||||
expected_discrete[M(2)] = 1;
|
||||
EXPECT(assert_equal(expected_discrete, result.discrete()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
@ -235,7 +235,7 @@ TEST(HybridGaussianElimination, Approx_inference) {
|
|||
size_t maxNrLeaves = 5;
|
||||
incrementalHybrid.update(graph1);
|
||||
|
||||
incrementalHybrid.prune(M(3), maxNrLeaves);
|
||||
incrementalHybrid.prune(maxNrLeaves);
|
||||
|
||||
/*
|
||||
unpruned factor is:
|
||||
|
|
@ -329,7 +329,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
|
|||
// Run update with pruning
|
||||
size_t maxComponents = 5;
|
||||
incrementalHybrid.update(graph1);
|
||||
incrementalHybrid.prune(M(3), maxComponents);
|
||||
incrementalHybrid.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with 4 hybrid nodes,
|
||||
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
|
||||
|
|
@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
|
|||
EXPECT_LONGS_EQUAL(
|
||||
2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
|
||||
3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
|
|
@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
|
|||
|
||||
// Run update with pruning a second time.
|
||||
incrementalHybrid.update(graph2);
|
||||
incrementalHybrid.prune(M(4), maxComponents);
|
||||
incrementalHybrid.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with pruned hybrid nodes,
|
||||
// with 5 (pruned) leaves.
|
||||
|
|
@ -399,7 +399,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
initial.insert(Z(0), Pose2(0.0, 2.0, 0.0));
|
||||
initial.insert(W(0), Pose2(0.0, 3.0, 0.0));
|
||||
|
||||
HybridGaussianFactorGraph gfg = fg.linearize(initial);
|
||||
HybridGaussianFactorGraph gfg = *fg.linearize(initial);
|
||||
fg = HybridNonlinearFactorGraph();
|
||||
|
||||
HybridGaussianISAM inc;
|
||||
|
|
@ -444,7 +444,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
// The leg link did not move so we set the expected pose accordingly.
|
||||
initial.insert(W(1), Pose2(0.0, 3.0, 0.0));
|
||||
|
||||
gfg = fg.linearize(initial);
|
||||
gfg = *fg.linearize(initial);
|
||||
fg = HybridNonlinearFactorGraph();
|
||||
|
||||
// Update without pruning
|
||||
|
|
@ -483,7 +483,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
initial.insert(Z(2), Pose2(2.0, 2.0, 0.0));
|
||||
initial.insert(W(2), Pose2(0.0, 3.0, 0.0));
|
||||
|
||||
gfg = fg.linearize(initial);
|
||||
gfg = *fg.linearize(initial);
|
||||
fg = HybridNonlinearFactorGraph();
|
||||
|
||||
// Now we prune!
|
||||
|
|
@ -496,7 +496,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
// The MHS at this point should be a 2 level tree on (1, 2).
|
||||
// 1 has 2 choices, and 2 has 4 choices.
|
||||
inc.update(gfg);
|
||||
inc.prune(M(2), 2);
|
||||
inc.prune(2);
|
||||
|
||||
/*************** Run Round 4 ***************/
|
||||
// Add odometry factor with discrete modes.
|
||||
|
|
@ -526,12 +526,12 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
initial.insert(Z(3), Pose2(3.0, 2.0, 0.0));
|
||||
initial.insert(W(3), Pose2(0.0, 3.0, 0.0));
|
||||
|
||||
gfg = fg.linearize(initial);
|
||||
gfg = *fg.linearize(initial);
|
||||
fg = HybridNonlinearFactorGraph();
|
||||
|
||||
// Keep pruning!
|
||||
inc.update(gfg);
|
||||
inc.prune(M(3), 3);
|
||||
inc.prune(3);
|
||||
|
||||
// The final discrete graph should not be empty since we have eliminated
|
||||
// all continuous variables.
|
||||
|
|
@ -1,272 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testHybridLookupDAG.cpp
|
||||
* @date Aug, 2022
|
||||
* @author Shangjie Xue
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
#include <gtsam/nonlinear/Values.h>
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
||||
TEST(HybridLookupTable, basics) {
|
||||
// create a conditional gaussian node
|
||||
Matrix S1(2, 2);
|
||||
S1(0, 0) = 1;
|
||||
S1(1, 0) = 2;
|
||||
S1(0, 1) = 3;
|
||||
S1(1, 1) = 4;
|
||||
|
||||
Matrix S2(2, 2);
|
||||
S2(0, 0) = 6;
|
||||
S2(1, 0) = 0.2;
|
||||
S2(0, 1) = 8;
|
||||
S2(1, 1) = 0.4;
|
||||
|
||||
Matrix R1(2, 2);
|
||||
R1(0, 0) = 0.1;
|
||||
R1(1, 0) = 0.3;
|
||||
R1(0, 1) = 0.0;
|
||||
R1(1, 1) = 0.34;
|
||||
|
||||
Matrix R2(2, 2);
|
||||
R2(0, 0) = 0.1;
|
||||
R2(1, 0) = 0.3;
|
||||
R2(0, 1) = 0.0;
|
||||
R2(1, 1) = 0.34;
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
Vector2 d1(0.2, 0.5), d2(0.5, 0.2);
|
||||
|
||||
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
|
||||
X(2), S1, model),
|
||||
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
|
||||
X(2), S2, model);
|
||||
|
||||
// Create decision tree
|
||||
DiscreteKey m1(1, 2);
|
||||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals);
|
||||
|
||||
boost::shared_ptr<GaussianMixture> mixtureFactor(
|
||||
new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals));
|
||||
|
||||
HybridConditional hc(mixtureFactor);
|
||||
|
||||
GaussianMixture::Conditionals conditional2 =
|
||||
boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals();
|
||||
|
||||
DiscreteValues dv;
|
||||
dv[1] = 1;
|
||||
|
||||
VectorValues cv;
|
||||
cv.insert(X(2), Vector2(0.0, 0.0));
|
||||
|
||||
HybridValues hv(dv, cv);
|
||||
|
||||
// std::cout << conditional2(values).markdown();
|
||||
EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6));
|
||||
EXPECT(conditional2(dv) == conditionals(dv));
|
||||
HybridLookupTable hlt(hc);
|
||||
|
||||
// hlt.argmaxInPlace(&hv);
|
||||
|
||||
HybridLookupDAG dag;
|
||||
dag.push_back(hlt);
|
||||
dag.argmax(hv);
|
||||
|
||||
// HybridBayesNet hbn;
|
||||
// hbn.push_back(hc);
|
||||
// hbn.optimize();
|
||||
}
|
||||
|
||||
TEST(HybridLookupTable, hybrid_argmax) {
|
||||
Matrix S1(2, 2);
|
||||
S1(0, 0) = 1;
|
||||
S1(1, 0) = 0;
|
||||
S1(0, 1) = 0;
|
||||
S1(1, 1) = 1;
|
||||
|
||||
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
auto conditional0 =
|
||||
boost::make_shared<GaussianConditional>(X(1), d1, S1, model),
|
||||
conditional1 =
|
||||
boost::make_shared<GaussianConditional>(X(1), d2, S1, model);
|
||||
|
||||
DiscreteKey m1(1, 2);
|
||||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
boost::shared_ptr<GaussianMixture> mixtureFactor(
|
||||
new GaussianMixture({X(1)}, {}, {m1}, conditionals));
|
||||
|
||||
HybridConditional hc(mixtureFactor);
|
||||
|
||||
DiscreteValues dv;
|
||||
dv[1] = 1;
|
||||
VectorValues cv;
|
||||
// cv.insert(X(2),Vector2(0.0, 0.0));
|
||||
HybridValues hv(dv, cv);
|
||||
|
||||
HybridLookupTable hlt(hc);
|
||||
|
||||
hlt.argmaxInPlace(&hv);
|
||||
|
||||
EXPECT(assert_equal(hv.at(X(1)), d2));
|
||||
}
|
||||
|
||||
TEST(HybridLookupTable, discrete_argmax) {
|
||||
DiscreteKey X(0, 2), Y(1, 2);
|
||||
|
||||
auto conditional = boost::make_shared<DiscreteConditional>(X | Y = "0/1 3/2");
|
||||
|
||||
HybridConditional hc(conditional);
|
||||
|
||||
HybridLookupTable hlt(hc);
|
||||
|
||||
DiscreteValues dv;
|
||||
dv[1] = 0;
|
||||
VectorValues cv;
|
||||
// cv.insert(X(2),Vector2(0.0, 0.0));
|
||||
HybridValues hv(dv, cv);
|
||||
|
||||
hlt.argmaxInPlace(&hv);
|
||||
|
||||
EXPECT(assert_equal(hv.atDiscrete(0), 1));
|
||||
|
||||
DecisionTreeFactor f1(X, "2 3");
|
||||
auto conditional2 = boost::make_shared<DiscreteConditional>(1, f1);
|
||||
|
||||
HybridConditional hc2(conditional2);
|
||||
|
||||
HybridLookupTable hlt2(hc2);
|
||||
|
||||
HybridValues hv2;
|
||||
|
||||
hlt2.argmaxInPlace(&hv2);
|
||||
|
||||
EXPECT(assert_equal(hv2.atDiscrete(0), 1));
|
||||
}
|
||||
|
||||
TEST(HybridLookupTable, gaussian_argmax) {
|
||||
Matrix S1(2, 2);
|
||||
S1(0, 0) = 1;
|
||||
S1(1, 0) = 0;
|
||||
S1(0, 1) = 0;
|
||||
S1(1, 1) = 1;
|
||||
|
||||
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
auto conditional =
|
||||
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
|
||||
|
||||
HybridConditional hc(conditional);
|
||||
|
||||
HybridLookupTable hlt(hc);
|
||||
|
||||
DiscreteValues dv;
|
||||
// dv[1]=0;
|
||||
VectorValues cv;
|
||||
cv.insert(X(2), d2);
|
||||
HybridValues hv(dv, cv);
|
||||
|
||||
hlt.argmaxInPlace(&hv);
|
||||
|
||||
EXPECT(assert_equal(hv.at(X(1)), d1 + d2));
|
||||
}
|
||||
|
||||
TEST(HybridLookupDAG, argmax) {
|
||||
Matrix S1(2, 2);
|
||||
S1(0, 0) = 1;
|
||||
S1(1, 0) = 0;
|
||||
S1(0, 1) = 0;
|
||||
S1(1, 1) = 1;
|
||||
|
||||
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
auto conditional0 =
|
||||
boost::make_shared<GaussianConditional>(X(2), d1, S1, model),
|
||||
conditional1 =
|
||||
boost::make_shared<GaussianConditional>(X(2), d2, S1, model);
|
||||
|
||||
DiscreteKey m1(1, 2);
|
||||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
boost::shared_ptr<GaussianMixture> mixtureFactor(
|
||||
new GaussianMixture({X(2)}, {}, {m1}, conditionals));
|
||||
HybridConditional hc2(mixtureFactor);
|
||||
HybridLookupTable hlt2(hc2);
|
||||
|
||||
auto conditional2 =
|
||||
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
|
||||
|
||||
HybridConditional hc1(conditional2);
|
||||
HybridLookupTable hlt1(hc1);
|
||||
|
||||
DecisionTreeFactor f1(m1, "2 3");
|
||||
auto discrete_conditional = boost::make_shared<DiscreteConditional>(1, f1);
|
||||
|
||||
HybridConditional hc3(discrete_conditional);
|
||||
HybridLookupTable hlt3(hc3);
|
||||
|
||||
HybridLookupDAG dag;
|
||||
dag.push_back(hlt1);
|
||||
dag.push_back(hlt2);
|
||||
dag.push_back(hlt3);
|
||||
auto hv = dag.argmax();
|
||||
|
||||
EXPECT(assert_equal(hv.atDiscrete(1), 1));
|
||||
EXPECT(assert_equal(hv.at(X(2)), d2));
|
||||
EXPECT(assert_equal(hv.at(X(1)), d2 + d1));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -60,7 +60,7 @@ TEST(HybridFactorGraph, GaussianFactorGraph) {
|
|||
Values linearizationPoint;
|
||||
linearizationPoint.insert<double>(X(0), 0);
|
||||
|
||||
HybridGaussianFactorGraph ghfg = fg.linearize(linearizationPoint);
|
||||
HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint);
|
||||
|
||||
// Add a factor to the GaussianFactorGraph
|
||||
ghfg.add(JacobianFactor(X(0), I_1x1, Vector1(5)));
|
||||
|
|
@ -139,7 +139,7 @@ TEST(HybridGaussianFactorGraph, Resize) {
|
|||
linearizationPoint.insert<double>(X(1), 1);
|
||||
|
||||
// Generate `HybridGaussianFactorGraph` by linearizing
|
||||
HybridGaussianFactorGraph gfg = nhfg.linearize(linearizationPoint);
|
||||
HybridGaussianFactorGraph gfg = *nhfg.linearize(linearizationPoint);
|
||||
|
||||
EXPECT_LONGS_EQUAL(gfg.size(), 3);
|
||||
|
||||
|
|
@ -147,6 +147,32 @@ TEST(HybridGaussianFactorGraph, Resize) {
|
|||
EXPECT_LONGS_EQUAL(gfg.size(), 0);
|
||||
}
|
||||
|
||||
/***************************************************************************
|
||||
* Test that the MixtureFactor reports correctly if the number of continuous
|
||||
* keys provided do not match the keys in the factors.
|
||||
*/
|
||||
TEST(HybridGaussianFactorGraph, MixtureFactor) {
|
||||
auto nonlinearFactor = boost::make_shared<BetweenFactor<double>>(
|
||||
X(0), X(1), 0.0, Isotropic::Sigma(1, 0.1));
|
||||
auto discreteFactor = boost::make_shared<DecisionTreeFactor>();
|
||||
|
||||
auto noise_model = noiseModel::Isotropic::Sigma(1, 1.0);
|
||||
auto still = boost::make_shared<MotionModel>(X(0), X(1), 0.0, noise_model),
|
||||
moving = boost::make_shared<MotionModel>(X(0), X(1), 1.0, noise_model);
|
||||
|
||||
std::vector<MotionModel::shared_ptr> components = {still, moving};
|
||||
|
||||
// Check for exception when number of continuous keys are under-specified.
|
||||
KeyVector contKeys = {X(0)};
|
||||
THROWS_EXCEPTION(boost::make_shared<MixtureFactor>(
|
||||
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components));
|
||||
|
||||
// Check for exception when number of continuous keys are too many.
|
||||
contKeys = {X(0), X(1), X(2)};
|
||||
THROWS_EXCEPTION(boost::make_shared<MixtureFactor>(
|
||||
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components));
|
||||
}
|
||||
|
||||
/*****************************************************************************
|
||||
* Test push_back on HFG makes the correct distinction.
|
||||
*/
|
||||
|
|
@ -224,7 +250,7 @@ TEST(HybridFactorGraph, Linearization) {
|
|||
|
||||
// Linearize here:
|
||||
HybridGaussianFactorGraph actualLinearized =
|
||||
self.nonlinearFactorGraph.linearize(self.linearizationPoint);
|
||||
*self.nonlinearFactorGraph.linearize(self.linearizationPoint);
|
||||
|
||||
EXPECT_LONGS_EQUAL(7, actualLinearized.size());
|
||||
}
|
||||
|
|
@ -257,14 +283,6 @@ TEST(GaussianElimination, Eliminate_x1) {
|
|||
// Add first hybrid factor
|
||||
factors.push_back(self.linearizedFactorGraph[1]);
|
||||
|
||||
// TODO(Varun) remove this block since sum is no longer exposed.
|
||||
// // Check that sum works:
|
||||
// auto sum = factors.sum();
|
||||
// Assignment<Key> mode;
|
||||
// mode[M(1)] = 1;
|
||||
// auto actual = sum(mode); // Selects one of 2 modes.
|
||||
// EXPECT_LONGS_EQUAL(2, actual.size()); // Prior and motion model.
|
||||
|
||||
// Eliminate x1
|
||||
Ordering ordering;
|
||||
ordering += X(1);
|
||||
|
|
@ -289,15 +307,6 @@ TEST(HybridsGaussianElimination, Eliminate_x2) {
|
|||
factors.push_back(self.linearizedFactorGraph[1]); // involves m1
|
||||
factors.push_back(self.linearizedFactorGraph[2]); // involves m2
|
||||
|
||||
// TODO(Varun) remove this block since sum is no longer exposed.
|
||||
// // Check that sum works:
|
||||
// auto sum = factors.sum();
|
||||
// Assignment<Key> mode;
|
||||
// mode[M(1)] = 0;
|
||||
// mode[M(2)] = 1;
|
||||
// auto actual = sum(mode); // Selects one of 4 mode
|
||||
// combinations. EXPECT_LONGS_EQUAL(2, actual.size()); // 2 motion models.
|
||||
|
||||
// Eliminate x2
|
||||
Ordering ordering;
|
||||
ordering += X(2);
|
||||
|
|
@ -364,51 +373,10 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
|
|||
CHECK(discreteFactor);
|
||||
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
|
||||
EXPECT(discreteFactor->root_->isLeaf() == false);
|
||||
|
||||
// TODO(Varun) Test emplace_discrete
|
||||
}
|
||||
|
||||
// /*
|
||||
// ****************************************************************************/
|
||||
// /// Test the toDecisionTreeFactor method
|
||||
// TEST(HybridFactorGraph, ToDecisionTreeFactor) {
|
||||
// size_t K = 3;
|
||||
|
||||
// // Provide tight sigma values so that the errors are visibly different.
|
||||
// double between_sigma = 5e-8, prior_sigma = 1e-7;
|
||||
|
||||
// Switching self(K, between_sigma, prior_sigma);
|
||||
|
||||
// // Clear out discrete factors since sum() cannot hanldle those
|
||||
// HybridGaussianFactorGraph linearizedFactorGraph(
|
||||
// self.linearizedFactorGraph.gaussianGraph(), DiscreteFactorGraph(),
|
||||
// self.linearizedFactorGraph.dcGraph());
|
||||
|
||||
// auto decisionTreeFactor = linearizedFactorGraph.toDecisionTreeFactor();
|
||||
|
||||
// auto allAssignments =
|
||||
// DiscreteValues::CartesianProduct(linearizedFactorGraph.discreteKeys());
|
||||
|
||||
// // Get the error of the discrete assignment m1=0, m2=1.
|
||||
// double actual = (*decisionTreeFactor)(allAssignments[1]);
|
||||
|
||||
// /********************************************/
|
||||
// // Create equivalent factor graph for m1=0, m2=1
|
||||
// GaussianFactorGraph graph = linearizedFactorGraph.gaussianGraph();
|
||||
|
||||
// for (auto &p : linearizedFactorGraph.dcGraph()) {
|
||||
// if (auto mixture =
|
||||
// boost::dynamic_pointer_cast<DCGaussianMixtureFactor>(p)) {
|
||||
// graph.add((*mixture)(allAssignments[1]));
|
||||
// }
|
||||
// }
|
||||
|
||||
// VectorValues values = graph.optimize();
|
||||
// double expected = graph.probPrime(values);
|
||||
// /********************************************/
|
||||
// EXPECT_DOUBLES_EQUAL(expected, actual, 1e-12);
|
||||
// // REGRESSION:
|
||||
// EXPECT_DOUBLES_EQUAL(0.6125, actual, 1e-4);
|
||||
// }
|
||||
|
||||
/****************************************************************************
|
||||
* Test partial elimination
|
||||
*/
|
||||
|
|
@ -428,7 +396,6 @@ TEST(HybridFactorGraph, Partial_Elimination) {
|
|||
linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||
|
||||
CHECK(hybridBayesNet);
|
||||
// GTSAM_PRINT(*hybridBayesNet); // HybridBayesNet
|
||||
EXPECT_LONGS_EQUAL(3, hybridBayesNet->size());
|
||||
EXPECT(hybridBayesNet->at(0)->frontals() == KeyVector{X(1)});
|
||||
EXPECT(hybridBayesNet->at(0)->parents() == KeyVector({X(2), M(1)}));
|
||||
|
|
@ -438,7 +405,6 @@ TEST(HybridFactorGraph, Partial_Elimination) {
|
|||
EXPECT(hybridBayesNet->at(2)->parents() == KeyVector({M(1), M(2)}));
|
||||
|
||||
CHECK(remainingFactorGraph);
|
||||
// GTSAM_PRINT(*remainingFactorGraph); // HybridFactorGraph
|
||||
EXPECT_LONGS_EQUAL(3, remainingFactorGraph->size());
|
||||
EXPECT(remainingFactorGraph->at(0)->keys() == KeyVector({M(1)}));
|
||||
EXPECT(remainingFactorGraph->at(1)->keys() == KeyVector({M(2), M(1)}));
|
||||
|
|
@ -721,13 +687,8 @@ TEST(HybridFactorGraph, DefaultDecisionTree) {
|
|||
moving = boost::make_shared<PlanarMotionModel>(X(0), X(1), odometry,
|
||||
noise_model);
|
||||
std::vector<PlanarMotionModel::shared_ptr> motion_models = {still, moving};
|
||||
// TODO(Varun) Make a templated constructor for MixtureFactor which does this?
|
||||
std::vector<NonlinearFactor::shared_ptr> components;
|
||||
for (auto&& f : motion_models) {
|
||||
components.push_back(boost::dynamic_pointer_cast<NonlinearFactor>(f));
|
||||
}
|
||||
fg.emplace_hybrid<MixtureFactor>(
|
||||
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components);
|
||||
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, motion_models);
|
||||
|
||||
// Add Range-Bearing measurements to from X0 to L0 and X1 to L1.
|
||||
// create a noise model for the landmark measurements
|
||||
|
|
@ -757,7 +718,7 @@ TEST(HybridFactorGraph, DefaultDecisionTree) {
|
|||
ordering += X(0);
|
||||
ordering += X(1);
|
||||
|
||||
HybridGaussianFactorGraph linearized = fg.linearize(initialEstimate);
|
||||
HybridGaussianFactorGraph linearized = *fg.linearize(initialEstimate);
|
||||
gtsam::HybridBayesNet::shared_ptr hybridBayesNet;
|
||||
gtsam::HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,586 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testHybridNonlinearISAM.cpp
|
||||
* @brief Unit tests for nonlinear incremental inference
|
||||
* @author Varun Agrawal, Fan Jiang, Frank Dellaert
|
||||
* @date Jan 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/geometry/Pose2.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearISAM.h>
|
||||
#include <gtsam/linear/GaussianBayesNet.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
#include <gtsam/nonlinear/PriorFactor.h>
|
||||
#include <gtsam/sam/BearingRangeFactor.h>
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "Switching.h"
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::L;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::W;
|
||||
using symbol_shorthand::X;
|
||||
using symbol_shorthand::Y;
|
||||
using symbol_shorthand::Z;
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test if we can perform elimination incrementally.
|
||||
TEST(HybridNonlinearISAM, IncrementalElimination) {
|
||||
Switching switching(3);
|
||||
HybridNonlinearISAM isam;
|
||||
HybridNonlinearFactorGraph graph1;
|
||||
Values initial;
|
||||
|
||||
// Create initial factor graph
|
||||
// * * *
|
||||
// | | |
|
||||
// X1 -*- X2 -*- X3
|
||||
// \*-M1-*/
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(0)); // P(X1)
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(1)); // P(X1, X2 | M1)
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(2)); // P(X2, X3 | M2)
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(5)); // P(M1)
|
||||
|
||||
initial.insert<double>(X(1), 1);
|
||||
initial.insert<double>(X(2), 2);
|
||||
initial.insert<double>(X(3), 3);
|
||||
|
||||
// Run update step
|
||||
isam.update(graph1, initial);
|
||||
|
||||
// Check that after update we have 3 hybrid Bayes net nodes:
|
||||
// P(X1 | X2, M1) and P(X2, X3 | M1, M2), P(M1, M2)
|
||||
HybridGaussianISAM bayesTree = isam.bayesTree();
|
||||
EXPECT_LONGS_EQUAL(3, bayesTree.size());
|
||||
EXPECT(bayesTree[X(1)]->conditional()->frontals() == KeyVector{X(1)});
|
||||
EXPECT(bayesTree[X(1)]->conditional()->parents() == KeyVector({X(2), M(1)}));
|
||||
EXPECT(bayesTree[X(2)]->conditional()->frontals() == KeyVector({X(2), X(3)}));
|
||||
EXPECT(bayesTree[X(2)]->conditional()->parents() == KeyVector({M(1), M(2)}));
|
||||
|
||||
/********************************************************/
|
||||
// New factor graph for incremental update.
|
||||
HybridNonlinearFactorGraph graph2;
|
||||
initial = Values();
|
||||
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(3)); // P(X2)
|
||||
graph2.push_back(switching.nonlinearFactorGraph.at(4)); // P(X3)
|
||||
graph2.push_back(switching.nonlinearFactorGraph.at(6)); // P(M1, M2)
|
||||
|
||||
isam.update(graph2, initial);
|
||||
|
||||
bayesTree = isam.bayesTree();
|
||||
// Check that after the second update we have
|
||||
// 1 additional hybrid Bayes net node:
|
||||
// P(X2, X3 | M1, M2)
|
||||
EXPECT_LONGS_EQUAL(3, bayesTree.size());
|
||||
EXPECT(bayesTree[X(3)]->conditional()->frontals() == KeyVector({X(2), X(3)}));
|
||||
EXPECT(bayesTree[X(3)]->conditional()->parents() == KeyVector({M(1), M(2)}));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test if we can incrementally do the inference
|
||||
TEST(HybridNonlinearISAM, IncrementalInference) {
|
||||
Switching switching(3);
|
||||
HybridNonlinearISAM isam;
|
||||
HybridNonlinearFactorGraph graph1;
|
||||
Values initial;
|
||||
|
||||
// Create initial factor graph
|
||||
// * * *
|
||||
// | | |
|
||||
// X1 -*- X2 -*- X3
|
||||
// | |
|
||||
// *-M1 - * - M2
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(0)); // P(X1)
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(1)); // P(X1, X2 | M1)
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(3)); // P(X2)
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(5)); // P(M1)
|
||||
|
||||
initial.insert<double>(X(1), 1);
|
||||
initial.insert<double>(X(2), 2);
|
||||
|
||||
// Run update step
|
||||
isam.update(graph1, initial);
|
||||
HybridGaussianISAM bayesTree = isam.bayesTree();
|
||||
|
||||
auto discreteConditional_m1 =
|
||||
bayesTree[M(1)]->conditional()->asDiscreteConditional();
|
||||
EXPECT(discreteConditional_m1->keys() == KeyVector({M(1)}));
|
||||
|
||||
/********************************************************/
|
||||
// New factor graph for incremental update.
|
||||
HybridNonlinearFactorGraph graph2;
|
||||
initial = Values();
|
||||
|
||||
initial.insert<double>(X(3), 3);
|
||||
|
||||
graph2.push_back(switching.nonlinearFactorGraph.at(2)); // P(X2, X3 | M2)
|
||||
graph2.push_back(switching.nonlinearFactorGraph.at(4)); // P(X3)
|
||||
graph2.push_back(switching.nonlinearFactorGraph.at(6)); // P(M1, M2)
|
||||
|
||||
isam.update(graph2, initial);
|
||||
bayesTree = isam.bayesTree();
|
||||
|
||||
/********************************************************/
|
||||
// Run batch elimination so we can compare results.
|
||||
Ordering ordering;
|
||||
ordering += X(1);
|
||||
ordering += X(2);
|
||||
ordering += X(3);
|
||||
|
||||
// Now we calculate the actual factors using full elimination
|
||||
HybridBayesTree::shared_ptr expectedHybridBayesTree;
|
||||
HybridGaussianFactorGraph::shared_ptr expectedRemainingGraph;
|
||||
std::tie(expectedHybridBayesTree, expectedRemainingGraph) =
|
||||
switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering);
|
||||
|
||||
// The densities on X(1) should be the same
|
||||
auto x1_conditional = dynamic_pointer_cast<GaussianMixture>(
|
||||
bayesTree[X(1)]->conditional()->inner());
|
||||
auto actual_x1_conditional = dynamic_pointer_cast<GaussianMixture>(
|
||||
(*expectedHybridBayesTree)[X(1)]->conditional()->inner());
|
||||
EXPECT(assert_equal(*x1_conditional, *actual_x1_conditional));
|
||||
|
||||
// The densities on X(2) should be the same
|
||||
auto x2_conditional = dynamic_pointer_cast<GaussianMixture>(
|
||||
bayesTree[X(2)]->conditional()->inner());
|
||||
auto actual_x2_conditional = dynamic_pointer_cast<GaussianMixture>(
|
||||
(*expectedHybridBayesTree)[X(2)]->conditional()->inner());
|
||||
EXPECT(assert_equal(*x2_conditional, *actual_x2_conditional));
|
||||
|
||||
// The densities on X(3) should be the same
|
||||
auto x3_conditional = dynamic_pointer_cast<GaussianMixture>(
|
||||
bayesTree[X(3)]->conditional()->inner());
|
||||
auto actual_x3_conditional = dynamic_pointer_cast<GaussianMixture>(
|
||||
(*expectedHybridBayesTree)[X(2)]->conditional()->inner());
|
||||
EXPECT(assert_equal(*x3_conditional, *actual_x3_conditional));
|
||||
|
||||
// We only perform manual continuous elimination for 0,0.
|
||||
// The other discrete probabilities on M(2) are calculated the same way
|
||||
Ordering discrete_ordering;
|
||||
discrete_ordering += M(1);
|
||||
discrete_ordering += M(2);
|
||||
HybridBayesTree::shared_ptr discreteBayesTree =
|
||||
expectedRemainingGraph->eliminateMultifrontal(discrete_ordering);
|
||||
|
||||
DiscreteValues m00;
|
||||
m00[M(1)] = 0, m00[M(2)] = 0;
|
||||
DiscreteConditional decisionTree =
|
||||
*(*discreteBayesTree)[M(2)]->conditional()->asDiscreteConditional();
|
||||
double m00_prob = decisionTree(m00);
|
||||
|
||||
auto discreteConditional =
|
||||
bayesTree[M(2)]->conditional()->asDiscreteConditional();
|
||||
|
||||
// Test if the probability values are as expected with regression tests.
|
||||
DiscreteValues assignment;
|
||||
EXPECT(assert_equal(m00_prob, 0.0619233, 1e-5));
|
||||
assignment[M(1)] = 0;
|
||||
assignment[M(2)] = 0;
|
||||
EXPECT(assert_equal(m00_prob, (*discreteConditional)(assignment), 1e-5));
|
||||
assignment[M(1)] = 1;
|
||||
assignment[M(2)] = 0;
|
||||
EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5));
|
||||
assignment[M(1)] = 0;
|
||||
assignment[M(2)] = 1;
|
||||
EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5));
|
||||
assignment[M(1)] = 1;
|
||||
assignment[M(2)] = 1;
|
||||
EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5));
|
||||
|
||||
// Check if the clique conditional generated from incremental elimination
|
||||
// matches that of batch elimination.
|
||||
auto expectedChordal = expectedRemainingGraph->eliminateMultifrontal();
|
||||
auto expectedConditional = dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
(*expectedChordal)[M(2)]->conditional()->inner());
|
||||
auto actualConditional = dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
bayesTree[M(2)]->conditional()->inner());
|
||||
EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test if we can approximately do the inference
|
||||
TEST(HybridNonlinearISAM, Approx_inference) {
|
||||
Switching switching(4);
|
||||
HybridNonlinearISAM incrementalHybrid;
|
||||
HybridNonlinearFactorGraph graph1;
|
||||
Values initial;
|
||||
|
||||
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
|
||||
for (size_t i = 1; i < 4; i++) {
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the Gaussian factors, 1 prior on X(1),
|
||||
// 3 measurements on X(2), X(3), X(4)
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(0));
|
||||
for (size_t i = 4; i <= 7; i++) {
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(i));
|
||||
initial.insert<double>(X(i - 3), i - 3);
|
||||
}
|
||||
|
||||
// Create ordering.
|
||||
Ordering ordering;
|
||||
for (size_t j = 1; j <= 4; j++) {
|
||||
ordering += X(j);
|
||||
}
|
||||
|
||||
// Now we calculate the actual factors using full elimination
|
||||
HybridBayesTree::shared_ptr unprunedHybridBayesTree;
|
||||
HybridGaussianFactorGraph::shared_ptr unprunedRemainingGraph;
|
||||
std::tie(unprunedHybridBayesTree, unprunedRemainingGraph) =
|
||||
switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering);
|
||||
|
||||
size_t maxNrLeaves = 5;
|
||||
incrementalHybrid.update(graph1, initial);
|
||||
HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
|
||||
|
||||
bayesTree.prune(maxNrLeaves);
|
||||
|
||||
/*
|
||||
unpruned factor is:
|
||||
Choice(m3)
|
||||
0 Choice(m2)
|
||||
0 0 Choice(m1)
|
||||
0 0 0 Leaf 0.11267528
|
||||
0 0 1 Leaf 0.18576102
|
||||
0 1 Choice(m1)
|
||||
0 1 0 Leaf 0.18754662
|
||||
0 1 1 Leaf 0.30623871
|
||||
1 Choice(m2)
|
||||
1 0 Choice(m1)
|
||||
1 0 0 Leaf 0.18576102
|
||||
1 0 1 Leaf 0.30622428
|
||||
1 1 Choice(m1)
|
||||
1 1 0 Leaf 0.30623871
|
||||
1 1 1 Leaf 0.5
|
||||
|
||||
pruned factors is:
|
||||
Choice(m3)
|
||||
0 Choice(m2)
|
||||
0 0 Leaf 0
|
||||
0 1 Choice(m1)
|
||||
0 1 0 Leaf 0.18754662
|
||||
0 1 1 Leaf 0.30623871
|
||||
1 Choice(m2)
|
||||
1 0 Choice(m1)
|
||||
1 0 0 Leaf 0
|
||||
1 0 1 Leaf 0.30622428
|
||||
1 1 Choice(m1)
|
||||
1 1 0 Leaf 0.30623871
|
||||
1 1 1 Leaf 0.5
|
||||
*/
|
||||
|
||||
auto discreteConditional_m1 = *dynamic_pointer_cast<DiscreteConditional>(
|
||||
bayesTree[M(1)]->conditional()->inner());
|
||||
EXPECT(discreteConditional_m1.keys() == KeyVector({M(1), M(2), M(3)}));
|
||||
|
||||
// Get the number of elements which are greater than 0.
|
||||
auto count = [](const double &value, int count) {
|
||||
return value > 0 ? count + 1 : count;
|
||||
};
|
||||
// Check that the number of leaves after pruning is 5.
|
||||
EXPECT_LONGS_EQUAL(5, discreteConditional_m1.fold(count, 0));
|
||||
|
||||
// Check that the hybrid nodes of the bayes net match those of the pre-pruning
|
||||
// bayes net, at the same positions.
|
||||
auto &unprunedLastDensity = *dynamic_pointer_cast<GaussianMixture>(
|
||||
unprunedHybridBayesTree->clique(X(4))->conditional()->inner());
|
||||
auto &lastDensity = *dynamic_pointer_cast<GaussianMixture>(
|
||||
bayesTree[X(4)]->conditional()->inner());
|
||||
|
||||
std::vector<std::pair<DiscreteValues, double>> assignments =
|
||||
discreteConditional_m1.enumerate();
|
||||
// Loop over all assignments and check the pruned components
|
||||
for (auto &&av : assignments) {
|
||||
const DiscreteValues &assignment = av.first;
|
||||
const double value = av.second;
|
||||
|
||||
if (value == 0.0) {
|
||||
EXPECT(lastDensity(assignment) == nullptr);
|
||||
} else {
|
||||
CHECK(lastDensity(assignment));
|
||||
EXPECT(assert_equal(*unprunedLastDensity(assignment),
|
||||
*lastDensity(assignment)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test approximate inference with an additional pruning step.
|
||||
TEST(HybridNonlinearISAM, Incremental_approximate) {
|
||||
Switching switching(5);
|
||||
HybridNonlinearISAM incrementalHybrid;
|
||||
HybridNonlinearFactorGraph graph1;
|
||||
Values initial;
|
||||
|
||||
/***** Run Round 1 *****/
|
||||
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
|
||||
for (size_t i = 1; i < 4; i++) {
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the Gaussian factors, 1 prior on X(1),
|
||||
// 3 measurements on X(2), X(3), X(4)
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(0));
|
||||
initial.insert<double>(X(1), 1);
|
||||
for (size_t i = 5; i <= 7; i++) {
|
||||
graph1.push_back(switching.nonlinearFactorGraph.at(i));
|
||||
initial.insert<double>(X(i - 3), i - 3);
|
||||
}
|
||||
|
||||
// Run update with pruning
|
||||
size_t maxComponents = 5;
|
||||
incrementalHybrid.update(graph1, initial);
|
||||
HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
|
||||
|
||||
bayesTree.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with 4 hybrid nodes,
|
||||
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
|
||||
EXPECT_LONGS_EQUAL(4, bayesTree.size());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
5, bayesTree[X(4)]->conditional()->asMixture()->nrComponents());
|
||||
|
||||
/***** Run Round 2 *****/
|
||||
HybridGaussianFactorGraph graph2;
|
||||
graph2.push_back(switching.nonlinearFactorGraph.at(4)); // x4-x5
|
||||
graph2.push_back(switching.nonlinearFactorGraph.at(8)); // x5 measurement
|
||||
initial = Values();
|
||||
initial.insert<double>(X(5), 5);
|
||||
|
||||
// Run update with pruning a second time.
|
||||
incrementalHybrid.update(graph2, initial);
|
||||
bayesTree = incrementalHybrid.bayesTree();
|
||||
|
||||
bayesTree.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with pruned hybrid nodes,
|
||||
// with 5 (pruned) leaves.
|
||||
CHECK_EQUAL(5, bayesTree.size());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
5, bayesTree[X(4)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
5, bayesTree[X(5)]->conditional()->asMixture()->nrComponents());
|
||||
}
|
||||
|
||||
/* ************************************************************************/
|
||||
// A GTSAM-only test for running inference on a single-legged robot.
|
||||
// The leg links are represented by the chain X-Y-Z-W, where X is the base and
|
||||
// W is the foot.
|
||||
// We use BetweenFactor<Pose2> as constraints between each of the poses.
|
||||
TEST(HybridNonlinearISAM, NonTrivial) {
|
||||
/*************** Run Round 1 ***************/
|
||||
HybridNonlinearFactorGraph fg;
|
||||
HybridNonlinearISAM inc;
|
||||
|
||||
// Add a prior on pose x1 at the origin.
|
||||
// A prior factor consists of a mean and
|
||||
// a noise model (covariance matrix)
|
||||
Pose2 prior(0.0, 0.0, 0.0); // prior mean is at origin
|
||||
auto priorNoise = noiseModel::Diagonal::Sigmas(
|
||||
Vector3(0.3, 0.3, 0.1)); // 30cm std on x,y, 0.1 rad on theta
|
||||
fg.emplace_nonlinear<PriorFactor<Pose2>>(X(0), prior, priorNoise);
|
||||
|
||||
// create a noise model for the landmark measurements
|
||||
auto poseNoise = noiseModel::Isotropic::Sigma(3, 0.1);
|
||||
|
||||
// We model a robot's single leg as X - Y - Z - W
|
||||
// where X is the base link and W is the foot link.
|
||||
|
||||
// Add connecting poses similar to PoseFactors in GTD
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(0), Y(0), Pose2(0, 1.0, 0),
|
||||
poseNoise);
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(Y(0), Z(0), Pose2(0, 1.0, 0),
|
||||
poseNoise);
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(Z(0), W(0), Pose2(0, 1.0, 0),
|
||||
poseNoise);
|
||||
|
||||
// Create initial estimate
|
||||
Values initial;
|
||||
initial.insert(X(0), Pose2(0.0, 0.0, 0.0));
|
||||
initial.insert(Y(0), Pose2(0.0, 1.0, 0.0));
|
||||
initial.insert(Z(0), Pose2(0.0, 2.0, 0.0));
|
||||
initial.insert(W(0), Pose2(0.0, 3.0, 0.0));
|
||||
|
||||
// Don't run update now since we don't have discrete variables involved.
|
||||
|
||||
using PlanarMotionModel = BetweenFactor<Pose2>;
|
||||
|
||||
/*************** Run Round 2 ***************/
|
||||
// Add odometry factor with discrete modes.
|
||||
Pose2 odometry(1.0, 0.0, 0.0);
|
||||
KeyVector contKeys = {W(0), W(1)};
|
||||
auto noise_model = noiseModel::Isotropic::Sigma(3, 1.0);
|
||||
auto still = boost::make_shared<PlanarMotionModel>(W(0), W(1), Pose2(0, 0, 0),
|
||||
noise_model),
|
||||
moving = boost::make_shared<PlanarMotionModel>(W(0), W(1), odometry,
|
||||
noise_model);
|
||||
std::vector<PlanarMotionModel::shared_ptr> components = {moving, still};
|
||||
auto mixtureFactor = boost::make_shared<MixtureFactor>(
|
||||
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components);
|
||||
fg.push_back(mixtureFactor);
|
||||
|
||||
// Add equivalent of ImuFactor
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(0), X(1), Pose2(1.0, 0.0, 0),
|
||||
poseNoise);
|
||||
// PoseFactors-like at k=1
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(1), Y(1), Pose2(0, 1, 0),
|
||||
poseNoise);
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(Y(1), Z(1), Pose2(0, 1, 0),
|
||||
poseNoise);
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(Z(1), W(1), Pose2(-1, 1, 0),
|
||||
poseNoise);
|
||||
|
||||
initial.insert(X(1), Pose2(1.0, 0.0, 0.0));
|
||||
initial.insert(Y(1), Pose2(1.0, 1.0, 0.0));
|
||||
initial.insert(Z(1), Pose2(1.0, 2.0, 0.0));
|
||||
// The leg link did not move so we set the expected pose accordingly.
|
||||
initial.insert(W(1), Pose2(0.0, 3.0, 0.0));
|
||||
|
||||
// Update without pruning
|
||||
// The result is a HybridBayesNet with 1 discrete variable M(1).
|
||||
// P(X | measurements) = P(W0|Z0, W1, M1) P(Z0|Y0, W1, M1) P(Y0|X0, W1, M1)
|
||||
// P(X0 | X1, W1, M1) P(W1|Z1, X1, M1) P(Z1|Y1, X1, M1)
|
||||
// P(Y1 | X1, M1)P(X1 | M1)P(M1)
|
||||
// The MHS tree is a 1 level tree for time indices (1,) with 2 leaves.
|
||||
inc.update(fg, initial);
|
||||
|
||||
fg = HybridNonlinearFactorGraph();
|
||||
initial = Values();
|
||||
|
||||
/*************** Run Round 3 ***************/
|
||||
// Add odometry factor with discrete modes.
|
||||
contKeys = {W(1), W(2)};
|
||||
still = boost::make_shared<PlanarMotionModel>(W(1), W(2), Pose2(0, 0, 0),
|
||||
noise_model);
|
||||
moving =
|
||||
boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry, noise_model);
|
||||
components = {moving, still};
|
||||
mixtureFactor = boost::make_shared<MixtureFactor>(
|
||||
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(2), 2)}, components);
|
||||
fg.push_back(mixtureFactor);
|
||||
|
||||
// Add equivalent of ImuFactor
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(1), X(2), Pose2(1.0, 0.0, 0),
|
||||
poseNoise);
|
||||
// PoseFactors-like at k=1
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(2), Y(2), Pose2(0, 1, 0),
|
||||
poseNoise);
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(Y(2), Z(2), Pose2(0, 1, 0),
|
||||
poseNoise);
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(Z(2), W(2), Pose2(-2, 1, 0),
|
||||
poseNoise);
|
||||
|
||||
initial.insert(X(2), Pose2(2.0, 0.0, 0.0));
|
||||
initial.insert(Y(2), Pose2(2.0, 1.0, 0.0));
|
||||
initial.insert(Z(2), Pose2(2.0, 2.0, 0.0));
|
||||
initial.insert(W(2), Pose2(0.0, 3.0, 0.0));
|
||||
|
||||
// Now we prune!
|
||||
// P(X | measurements) = P(W0|Z0, W1, M1) P(Z0|Y0, W1, M1) P(Y0|X0, W1, M1)
|
||||
// P(X0 | X1, W1, M1) P(W1|W2, Z1, X1, M1, M2)
|
||||
// P(Z1| W2, Y1, X1, M1, M2) P(Y1 | W2, X1, M1, M2)
|
||||
// P(X1 | W2, X2, M1, M2) P(W2|Z2, X2, M1, M2)
|
||||
// P(Z2|Y2, X2, M1, M2) P(Y2 | X2, M1, M2)
|
||||
// P(X2 | M1, M2) P(M1, M2)
|
||||
// The MHS at this point should be a 2 level tree on (1, 2).
|
||||
// 1 has 2 choices, and 2 has 4 choices.
|
||||
inc.update(fg, initial);
|
||||
inc.prune(2);
|
||||
|
||||
fg = HybridNonlinearFactorGraph();
|
||||
initial = Values();
|
||||
|
||||
/*************** Run Round 4 ***************/
|
||||
// Add odometry factor with discrete modes.
|
||||
contKeys = {W(2), W(3)};
|
||||
still = boost::make_shared<PlanarMotionModel>(W(2), W(3), Pose2(0, 0, 0),
|
||||
noise_model);
|
||||
moving =
|
||||
boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry, noise_model);
|
||||
components = {moving, still};
|
||||
mixtureFactor = boost::make_shared<MixtureFactor>(
|
||||
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(3), 2)}, components);
|
||||
fg.push_back(mixtureFactor);
|
||||
|
||||
// Add equivalent of ImuFactor
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(2), X(3), Pose2(1.0, 0.0, 0),
|
||||
poseNoise);
|
||||
// PoseFactors-like at k=3
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(X(3), Y(3), Pose2(0, 1, 0),
|
||||
poseNoise);
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(Y(3), Z(3), Pose2(0, 1, 0),
|
||||
poseNoise);
|
||||
fg.emplace_nonlinear<BetweenFactor<Pose2>>(Z(3), W(3), Pose2(-3, 1, 0),
|
||||
poseNoise);
|
||||
|
||||
initial.insert(X(3), Pose2(3.0, 0.0, 0.0));
|
||||
initial.insert(Y(3), Pose2(3.0, 1.0, 0.0));
|
||||
initial.insert(Z(3), Pose2(3.0, 2.0, 0.0));
|
||||
initial.insert(W(3), Pose2(0.0, 3.0, 0.0));
|
||||
|
||||
// Keep pruning!
|
||||
inc.update(fg, initial);
|
||||
inc.prune(3);
|
||||
|
||||
fg = HybridNonlinearFactorGraph();
|
||||
initial = Values();
|
||||
|
||||
HybridGaussianISAM bayesTree = inc.bayesTree();
|
||||
|
||||
// The final discrete graph should not be empty since we have eliminated
|
||||
// all continuous variables.
|
||||
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional();
|
||||
EXPECT_LONGS_EQUAL(3, discreteTree->size());
|
||||
|
||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||
DiscreteFactorGraph discreteGraph;
|
||||
discreteGraph.push_back(discreteTree);
|
||||
DiscreteValues optimal_assignment = discreteGraph.optimize();
|
||||
|
||||
DiscreteValues expected_assignment;
|
||||
expected_assignment[M(1)] = 1;
|
||||
expected_assignment[M(2)] = 1;
|
||||
expected_assignment[M(3)] = 1;
|
||||
|
||||
EXPECT(assert_equal(expected_assignment, optimal_assignment));
|
||||
|
||||
// Test if pruning worked correctly by checking that
|
||||
// we only have 3 leaves in the last node.
|
||||
auto lastConditional = bayesTree[X(3)]->conditional()->asMixture();
|
||||
EXPECT_LONGS_EQUAL(3, lastConditional->nrComponents());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
|
|
@ -33,7 +33,6 @@ namespace gtsam {
|
|||
// Forward declarations
|
||||
template<class FACTOR> class FactorGraph;
|
||||
template<class BAYESTREE, class GRAPH> class EliminatableClusterTree;
|
||||
class HybridBayesTreeClique;
|
||||
|
||||
/* ************************************************************************* */
|
||||
/** clique statistics */
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ struct ConstructorTraversalData {
|
|||
typedef typename JunctionTree<BAYESTREE, GRAPH>::sharedNode sharedNode;
|
||||
|
||||
ConstructorTraversalData* const parentData;
|
||||
sharedNode myJTNode;
|
||||
sharedNode junctionTreeNode;
|
||||
FastVector<SymbolicConditional::shared_ptr> childSymbolicConditionals;
|
||||
FastVector<SymbolicFactor::shared_ptr> childSymbolicFactors;
|
||||
|
||||
|
|
@ -53,8 +53,9 @@ struct ConstructorTraversalData {
|
|||
// a traversal data structure with its own JT node, and create a child
|
||||
// pointer in its parent.
|
||||
ConstructorTraversalData myData = ConstructorTraversalData(&parentData);
|
||||
myData.myJTNode = boost::make_shared<Node>(node->key, node->factors);
|
||||
parentData.myJTNode->addChild(myData.myJTNode);
|
||||
myData.junctionTreeNode =
|
||||
boost::make_shared<Node>(node->key, node->factors);
|
||||
parentData.junctionTreeNode->addChild(myData.junctionTreeNode);
|
||||
return myData;
|
||||
}
|
||||
|
||||
|
|
@ -91,7 +92,7 @@ struct ConstructorTraversalData {
|
|||
myData.parentData->childSymbolicConditionals.push_back(myConditional);
|
||||
myData.parentData->childSymbolicFactors.push_back(mySeparatorFactor);
|
||||
|
||||
sharedNode node = myData.myJTNode;
|
||||
sharedNode node = myData.junctionTreeNode;
|
||||
const FastVector<SymbolicConditional::shared_ptr>& childConditionals =
|
||||
myData.childSymbolicConditionals;
|
||||
node->problemSize_ = (int) (myConditional->size() * symbolicFactors.size());
|
||||
|
|
@ -138,14 +139,14 @@ JunctionTree<BAYESTREE, GRAPH>::JunctionTree(
|
|||
typedef typename EliminationTree<ETREE_BAYESNET, ETREE_GRAPH>::Node ETreeNode;
|
||||
typedef ConstructorTraversalData<BAYESTREE, GRAPH, ETreeNode> Data;
|
||||
Data rootData(0);
|
||||
rootData.myJTNode = boost::make_shared<typename Base::Node>(); // Make a dummy node to gather
|
||||
// the junction tree roots
|
||||
// Make a dummy node to gather the junction tree roots
|
||||
rootData.junctionTreeNode = boost::make_shared<typename Base::Node>();
|
||||
treeTraversal::DepthFirstForest(eliminationTree, rootData,
|
||||
Data::ConstructorTraversalVisitorPre,
|
||||
Data::ConstructorTraversalVisitorPostAlg2);
|
||||
|
||||
// Assign roots from the dummy node
|
||||
this->addChildrenAsRoots(rootData.myJTNode);
|
||||
this->addChildrenAsRoots(rootData.junctionTreeNode);
|
||||
|
||||
// Transfer remaining factors from elimination tree
|
||||
Base::remainingFactors_ = eliminationTree.remainingFactors();
|
||||
|
|
|
|||
|
|
@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) {
|
|||
EXPECT(equalsBinary(graph));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
TEST(Serialization, gaussian_bayes_net) {
|
||||
// Create an arbitrary Bayes Net
|
||||
GaussianBayesNet gbn;
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3,
|
||||
(Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4,
|
||||
(Matrix2() << 11.0, 12.0, 13.0, 14.0).finished()));
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(),
|
||||
2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4,
|
||||
(Matrix2() << 25.0, 26.0, 27.0, 28.0).finished()));
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(),
|
||||
3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished()));
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(),
|
||||
4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished()));
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished()));
|
||||
|
||||
std::string serialized = serialize(gbn);
|
||||
GaussianBayesNet actual;
|
||||
deserialize(serialized, actual);
|
||||
EXPECT(assert_equal(gbn, actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST (Serialization, gaussian_bayes_tree) {
|
||||
const Key x1=1, x2=2, x3=3, x4=4;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,136 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DsfTrackGenerator.cpp
|
||||
* @date October 2022
|
||||
* @author John Lambert
|
||||
* @brief Identifies connected components in the keypoint matches graph.
|
||||
*/
|
||||
|
||||
#include <gtsam/sfm/DsfTrackGenerator.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
namespace gtsfm {
|
||||
|
||||
typedef DSFMap<IndexPair> DSFMapIndexPair;
|
||||
|
||||
/// Generate the DSF to form tracks.
|
||||
static DSFMapIndexPair generateDSF(const MatchIndicesMap& matches) {
|
||||
DSFMapIndexPair dsf;
|
||||
|
||||
for (const auto& kv : matches) {
|
||||
const auto pair_indices = kv.first;
|
||||
const auto corr_indices = kv.second;
|
||||
|
||||
// Image pair is (i1,i2).
|
||||
size_t i1 = pair_indices.first;
|
||||
size_t i2 = pair_indices.second;
|
||||
for (size_t k = 0; k < corr_indices.rows(); k++) {
|
||||
// Measurement indices are found in a single matrix row, as (k1,k2).
|
||||
size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1);
|
||||
// Unique key for DSF is (i,k), representing keypoint index in an image.
|
||||
dsf.merge({i1, k1}, {i2, k2});
|
||||
}
|
||||
}
|
||||
|
||||
return dsf;
|
||||
}
|
||||
|
||||
/// Generate a single track from a set of index pairs
|
||||
static SfmTrack2d trackFromIndexPairs(const std::set<IndexPair>& index_pair_set,
|
||||
const KeypointsVector& keypoints) {
|
||||
// Initialize track from measurements.
|
||||
SfmTrack2d track2d;
|
||||
|
||||
for (const auto& index_pair : index_pair_set) {
|
||||
// Camera index is represented by i, and measurement index is
|
||||
// represented by k.
|
||||
size_t i = index_pair.i();
|
||||
size_t k = index_pair.j();
|
||||
// Add measurement to this track.
|
||||
track2d.addMeasurement(i, keypoints[i].coordinates.row(k));
|
||||
}
|
||||
|
||||
return track2d;
|
||||
}
|
||||
|
||||
/// Generate tracks from the DSF.
|
||||
static std::vector<SfmTrack2d> tracksFromDSF(const DSFMapIndexPair& dsf,
|
||||
const KeypointsVector& keypoints) {
|
||||
const std::map<IndexPair, std::set<IndexPair> > key_sets = dsf.sets();
|
||||
|
||||
// Return immediately if no sets were found.
|
||||
if (key_sets.empty()) return {};
|
||||
|
||||
// Create a list of tracks.
|
||||
// Each track will be represented as a list of (camera_idx, measurements).
|
||||
std::vector<SfmTrack2d> tracks2d;
|
||||
for (const auto& kv : key_sets) {
|
||||
// Initialize track from measurements.
|
||||
SfmTrack2d track2d = trackFromIndexPairs(kv.second, keypoints);
|
||||
tracks2d.emplace_back(track2d);
|
||||
}
|
||||
return tracks2d;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Creates a list of tracks from 2d point correspondences.
|
||||
*
|
||||
* Creates a disjoint-set forest (DSF) and 2d tracks from pairwise matches.
|
||||
* We create a singleton for union-find set elements from camera index of a
|
||||
* detection and the index of that detection in that camera's keypoint list,
|
||||
* i.e. (i,k).
|
||||
*
|
||||
* @param Map from (i1,i2) image pair indices to (K,2) matrix, for K
|
||||
* correspondence indices, from each image.
|
||||
* @param Length-N list of keypoints, for N images/cameras.
|
||||
*/
|
||||
std::vector<SfmTrack2d> tracksFromPairwiseMatches(
|
||||
const MatchIndicesMap& matches, const KeypointsVector& keypoints,
|
||||
bool verbose) {
|
||||
// Generate the DSF to form tracks.
|
||||
if (verbose) std::cout << "[SfmTrack2d] Starting Union-Find..." << std::endl;
|
||||
DSFMapIndexPair dsf = generateDSF(matches);
|
||||
if (verbose) std::cout << "[SfmTrack2d] Union-Find Complete" << std::endl;
|
||||
|
||||
std::vector<SfmTrack2d> tracks2d = tracksFromDSF(dsf, keypoints);
|
||||
|
||||
// Filter out erroneous tracks that had repeated measurements within the
|
||||
// same image. This is an expected result from an incorrect correspondence
|
||||
// slipping through.
|
||||
std::vector<SfmTrack2d> validTracks;
|
||||
std::copy_if(
|
||||
tracks2d.begin(), tracks2d.end(), std::back_inserter(validTracks),
|
||||
[](const SfmTrack2d& track2d) { return track2d.hasUniqueCameras(); });
|
||||
|
||||
if (verbose) {
|
||||
size_t erroneous_track_count = tracks2d.size() - validTracks.size();
|
||||
double erroneous_percentage = static_cast<float>(erroneous_track_count) /
|
||||
static_cast<float>(tracks2d.size()) * 100;
|
||||
|
||||
std::cout << std::fixed << std::setprecision(2);
|
||||
std::cout << "DSF Union-Find: " << erroneous_percentage;
|
||||
std::cout << "% of tracks discarded from multiple obs. in a single image."
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// TODO(johnwlambert): return the Transitivity failure percentage here.
|
||||
return tracks2d;
|
||||
}
|
||||
|
||||
} // namespace gtsfm
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DsfTrackGenerator.h
|
||||
* @date July 2022
|
||||
* @author John Lambert
|
||||
* @brief Identifies connected components in the keypoint matches graph.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <gtsam/base/DSFMap.h>
|
||||
#include <gtsam/sfm/SfmTrack.h>
|
||||
|
||||
#include <Eigen/Core>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
namespace gtsfm {
|
||||
|
||||
typedef Eigen::MatrixX2i CorrespondenceIndices; // N x 2 array
|
||||
|
||||
// Output of detections in an image.
|
||||
// Coordinate system convention:
|
||||
// 1. The x coordinate denotes the horizontal direction (+ve direction towards
|
||||
// the right).
|
||||
// 2. The y coordinate denotes the vertical direction (+ve direction downwards).
|
||||
// 3. Origin is at the top left corner of the image.
|
||||
struct Keypoints {
|
||||
// The (x, y) coordinates of the features, of shape Nx2.
|
||||
Eigen::MatrixX2d coordinates;
|
||||
|
||||
// Optional scale of the detections, of shape N.
|
||||
// Note: gtsam::Vector is typedef'd for Eigen::VectorXd.
|
||||
boost::optional<gtsam::Vector> scales;
|
||||
|
||||
/// Optional confidences/responses for each detection, of shape N.
|
||||
boost::optional<gtsam::Vector> responses;
|
||||
|
||||
Keypoints(const Eigen::MatrixX2d& coordinates)
|
||||
: coordinates(coordinates){}; // boost::none
|
||||
};
|
||||
|
||||
using KeypointsVector = std::vector<Keypoints>;
|
||||
// Mapping from each image pair to (N,2) array representing indices of matching
|
||||
// keypoints.
|
||||
using MatchIndicesMap = std::map<IndexPair, CorrespondenceIndices>;
|
||||
|
||||
/**
|
||||
* @brief Creates a list of tracks from 2d point correspondences.
|
||||
*
|
||||
* Creates a disjoint-set forest (DSF) and 2d tracks from pairwise matches.
|
||||
* We create a singleton for union-find set elements from camera index of a
|
||||
* detection and the index of that detection in that camera's keypoint list,
|
||||
* i.e. (i,k).
|
||||
*
|
||||
* @param Map from (i1,i2) image pair indices to (K,2) matrix, for K
|
||||
* correspondence indices, from each image.
|
||||
* @param Length-N list of keypoints, for N images/cameras.
|
||||
*/
|
||||
std::vector<SfmTrack2d> tracksFromPairwiseMatches(
|
||||
const MatchIndicesMap& matches, const KeypointsVector& keypoints,
|
||||
bool verbose = false);
|
||||
|
||||
} // namespace gtsfm
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
@ -22,6 +22,7 @@
|
|||
#include <gtsam/geometry/Point2.h>
|
||||
#include <gtsam/geometry/Point3.h>
|
||||
|
||||
#include <Eigen/Core>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
|
@ -35,28 +36,26 @@ typedef std::pair<size_t, Point2> SfmMeasurement;
|
|||
typedef std::pair<size_t, size_t> SiftIndex;
|
||||
|
||||
/**
|
||||
* @brief An SfmTrack stores SfM measurements grouped in a track
|
||||
* @ingroup sfm
|
||||
* @brief Track containing 2D measurements associated with a single 3D point.
|
||||
* Note: Equivalent to gtsam.SfmTrack, but without the 3d measurement.
|
||||
* This class holds data temporarily before 3D point is initialized.
|
||||
*/
|
||||
struct GTSAM_EXPORT SfmTrack {
|
||||
Point3 p; ///< 3D position of the point
|
||||
float r, g, b; ///< RGB color of the 3D point
|
||||
|
||||
struct GTSAM_EXPORT SfmTrack2d {
|
||||
/// The 2D image projections (id,(u,v))
|
||||
std::vector<SfmMeasurement> measurements;
|
||||
|
||||
/// The feature descriptors
|
||||
/// The feature descriptors (optional)
|
||||
std::vector<SiftIndex> siftIndices;
|
||||
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
explicit SfmTrack(float r = 0, float g = 0, float b = 0)
|
||||
: p(0, 0, 0), r(r), g(g), b(b) {}
|
||||
// Default constructor.
|
||||
SfmTrack2d() = default;
|
||||
|
||||
explicit SfmTrack(const gtsam::Point3& pt, float r = 0, float g = 0,
|
||||
float b = 0)
|
||||
: p(pt), r(r), g(g), b(b) {}
|
||||
// Constructor from measurements.
|
||||
explicit SfmTrack2d(const std::vector<SfmMeasurement>& measurements)
|
||||
: measurements(measurements) {}
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
|
|
@ -78,6 +77,70 @@ struct GTSAM_EXPORT SfmTrack {
|
|||
/// Get the SIFT feature index corresponding to the measurement at `idx`
|
||||
const SiftIndex& siftIndex(size_t idx) const { return siftIndices[idx]; }
|
||||
|
||||
/**
|
||||
* @brief Check that no two measurements are from the same camera.
|
||||
* @returns boolean result of the validation.
|
||||
*/
|
||||
bool hasUniqueCameras() const {
|
||||
std::vector<int> track_cam_indices;
|
||||
for (auto& measurement : measurements) {
|
||||
track_cam_indices.emplace_back(measurement.first);
|
||||
}
|
||||
auto i =
|
||||
std::adjacent_find(track_cam_indices.begin(), track_cam_indices.end());
|
||||
bool all_cameras_unique = (i == track_cam_indices.end());
|
||||
return all_cameras_unique;
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Vectorized Interface
|
||||
/// @{
|
||||
|
||||
/// @brief Return the measurements as a 2D matrix
|
||||
Eigen::MatrixX2d measurementMatrix() const {
|
||||
Eigen::MatrixX2d m(numberMeasurements(), 2);
|
||||
for (size_t i = 0; i < numberMeasurements(); i++) {
|
||||
m.row(i) = measurement(i).second;
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
/// @brief Return the camera indices of the measurements
|
||||
Eigen::VectorXi indexVector() const {
|
||||
Eigen::VectorXi v(numberMeasurements());
|
||||
for (size_t i = 0; i < numberMeasurements(); i++) {
|
||||
v(i) = measurement(i).first;
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
using SfmTrack2dVector = std::vector<SfmTrack2d>;
|
||||
|
||||
/**
|
||||
* @brief An SfmTrack stores SfM measurements grouped in a track
|
||||
* @addtogroup sfm
|
||||
*/
|
||||
struct GTSAM_EXPORT SfmTrack : SfmTrack2d {
|
||||
Point3 p; ///< 3D position of the point
|
||||
float r, g, b; ///< RGB color of the 3D point
|
||||
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
explicit SfmTrack(float r = 0, float g = 0, float b = 0)
|
||||
: p(0, 0, 0), r(r), g(g), b(b) {}
|
||||
|
||||
explicit SfmTrack(const gtsam::Point3& pt, float r = 0, float g = 0,
|
||||
float b = 0)
|
||||
: p(pt), r(r), g(g), b(b) {}
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Get 3D point
|
||||
const Point3& point3() const { return p; }
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,23 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
#include <gtsam/nonlinear/Values.h>
|
||||
#include <gtsam/sfm/SfmTrack.h>
|
||||
class SfmTrack {
|
||||
class SfmTrack2d {
|
||||
std::vector<pair<size_t, gtsam::Point2>> measurements;
|
||||
|
||||
SfmTrack2d();
|
||||
SfmTrack2d(const std::vector<gtsam::SfmMeasurement>& measurements);
|
||||
size_t numberMeasurements() const;
|
||||
pair<size_t, gtsam::Point2> measurement(size_t idx) const;
|
||||
pair<size_t, size_t> siftIndex(size_t idx) const;
|
||||
void addMeasurement(size_t idx, const gtsam::Point2& m);
|
||||
gtsam::SfmMeasurement measurement(size_t idx) const;
|
||||
bool hasUniqueCameras() const;
|
||||
Eigen::MatrixX2d measurementMatrix() const;
|
||||
Eigen::VectorXi indexVector() const;
|
||||
};
|
||||
|
||||
virtual class SfmTrack : gtsam::SfmTrack2d {
|
||||
SfmTrack();
|
||||
SfmTrack(const gtsam::Point3& pt);
|
||||
const Point3& point3() const;
|
||||
|
|
@ -18,13 +31,6 @@ class SfmTrack {
|
|||
double g;
|
||||
double b;
|
||||
|
||||
std::vector<pair<size_t, gtsam::Point2>> measurements;
|
||||
|
||||
size_t numberMeasurements() const;
|
||||
pair<size_t, gtsam::Point2> measurement(size_t idx) const;
|
||||
pair<size_t, size_t> siftIndex(size_t idx) const;
|
||||
void addMeasurement(size_t idx, const gtsam::Point2& m);
|
||||
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
|
||||
|
|
@ -32,6 +38,8 @@ class SfmTrack {
|
|||
bool equals(const gtsam::SfmTrack& expected, double tol) const;
|
||||
};
|
||||
|
||||
#include <gtsam/nonlinear/Values.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
#include <gtsam/sfm/SfmData.h>
|
||||
class SfmData {
|
||||
SfmData();
|
||||
|
|
@ -115,7 +123,7 @@ class BinaryMeasurementsRot3 {
|
|||
|
||||
#include <gtsam/sfm/ShonanAveraging.h>
|
||||
|
||||
// TODO(frank): copy/pasta below until we have integer template paremeters in
|
||||
// TODO(frank): copy/pasta below until we have integer template parameters in
|
||||
// wrap!
|
||||
|
||||
class ShonanAveragingParameters2 {
|
||||
|
|
@ -310,4 +318,38 @@ class TranslationRecovery {
|
|||
const gtsam::BinaryMeasurementsUnit3& relativeTranslations) const;
|
||||
};
|
||||
|
||||
namespace gtsfm {
|
||||
|
||||
#include <gtsam/sfm/DsfTrackGenerator.h>
|
||||
|
||||
class MatchIndicesMap {
|
||||
MatchIndicesMap();
|
||||
MatchIndicesMap(const gtsam::gtsfm::MatchIndicesMap& other);
|
||||
|
||||
size_t size() const;
|
||||
bool empty() const;
|
||||
void clear();
|
||||
gtsam::gtsfm::CorrespondenceIndices at(const pair<size_t, size_t>& keypair) const;
|
||||
};
|
||||
|
||||
class Keypoints {
|
||||
Keypoints(const Eigen::MatrixX2d& coordinates);
|
||||
Eigen::MatrixX2d coordinates;
|
||||
};
|
||||
|
||||
class KeypointsVector {
|
||||
KeypointsVector();
|
||||
KeypointsVector(const gtsam::gtsfm::KeypointsVector& other);
|
||||
void push_back(const gtsam::gtsfm::Keypoints& keypoints);
|
||||
size_t size() const;
|
||||
bool empty() const;
|
||||
void clear();
|
||||
gtsam::gtsfm::Keypoints at(const size_t& index) const;
|
||||
};
|
||||
|
||||
gtsam::SfmTrack2dVector tracksFromPairwiseMatches(
|
||||
const gtsam::gtsfm::MatchIndicesMap& matches_dict,
|
||||
const gtsam::gtsfm::KeypointsVector& keypoints_list, bool verbose = false);
|
||||
} // namespace gtsfm
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -0,0 +1,53 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file TestSfmTrack.cpp
|
||||
* @date October 2022
|
||||
* @author Frank Dellaert
|
||||
* @brief tests for SfmTrack class
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
#include <gtsam/sfm/SfmTrack.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(SfmTrack2d, defaultConstructor) {
|
||||
SfmTrack2d track;
|
||||
EXPECT_LONGS_EQUAL(0, track.measurements.size());
|
||||
EXPECT_LONGS_EQUAL(0, track.siftIndices.size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(SfmTrack2d, measurementConstructor) {
|
||||
SfmTrack2d track({{0, Point2(1, 2)}, {1, Point2(3, 4)}});
|
||||
EXPECT_LONGS_EQUAL(2, track.measurements.size());
|
||||
EXPECT_LONGS_EQUAL(0, track.siftIndices.size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(SfmTrack, construction) {
|
||||
SfmTrack track(Point3(1, 2, 3), 4, 5, 6);
|
||||
EXPECT(assert_equal(Point3(1, 2, 3), track.point3()));
|
||||
EXPECT(assert_equal(Point3(4, 5, 6), track.rgb()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -51,7 +51,10 @@ set(ignore
|
|||
gtsam::BinaryMeasurementsUnit3
|
||||
gtsam::BinaryMeasurementsRot3
|
||||
gtsam::DiscreteKey
|
||||
gtsam::KeyPairDoubleMap)
|
||||
gtsam::KeyPairDoubleMap
|
||||
gtsam::gtsfm::MatchIndicesMap
|
||||
gtsam::gtsfm::KeypointsVector
|
||||
gtsam::gtsfm::SfmTrack2dVector)
|
||||
|
||||
set(interface_headers
|
||||
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
|
||||
|
|
@ -148,8 +151,12 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON)
|
|||
gtsam::CameraSetCal3Bundler
|
||||
gtsam::CameraSetCal3Unified
|
||||
gtsam::CameraSetCal3Fisheye
|
||||
gtsam::KeyPairDoubleMap)
|
||||
|
||||
gtsam::KeyPairDoubleMap
|
||||
gtsam::gtsfm::MatchIndicesMap
|
||||
gtsam::gtsfm::KeypointsVector
|
||||
gtsam::gtsfm::SfmTrack2dVector)
|
||||
|
||||
|
||||
pybind_wrap(${GTSAM_PYTHON_UNSTABLE_TARGET} # target
|
||||
${PROJECT_SOURCE_DIR}/gtsam_unstable/gtsam_unstable.i # interface_header
|
||||
"gtsam_unstable.cpp" # generated_cpp
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
# This trick is to allow direct import of sub-modules
|
||||
# without this, we can only do `from gtsam.gtsam.gtsfm import X`
|
||||
# with this trick, we can do `from gtsam.gtsfm import X`
|
||||
from .gtsam.gtsfm import *
|
||||
|
|
@ -15,12 +15,13 @@
|
|||
// #include <pybind11/stl.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(
|
||||
std::vector<gtsam::SfmTrack>);
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(
|
||||
std::vector<gtsam::SfmCamera>);
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::SfmMeasurement>);
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::SfmTrack>);
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::SfmCamera>);
|
||||
PYBIND11_MAKE_OPAQUE(
|
||||
std::vector<gtsam::BinaryMeasurement<gtsam::Unit3>>);
|
||||
PYBIND11_MAKE_OPAQUE(
|
||||
std::vector<gtsam::BinaryMeasurement<gtsam::Rot3>>);
|
||||
PYBIND11_MAKE_OPAQUE(
|
||||
std::vector<gtsam::gtsfm::Keypoints>);
|
||||
PYBIND11_MAKE_OPAQUE(gtsam::gtsfm::MatchIndicesMap);
|
||||
|
|
@ -18,16 +18,11 @@ py::bind_vector<std::vector<gtsam::BinaryMeasurement<gtsam::Unit3> > >(
|
|||
py::bind_vector<std::vector<gtsam::BinaryMeasurement<gtsam::Rot3> > >(
|
||||
m_, "BinaryMeasurementsRot3");
|
||||
py::bind_map<gtsam::KeyPairDoubleMap>(m_, "KeyPairDoubleMap");
|
||||
py::bind_vector<std::vector<gtsam::SfmTrack2d>>(m_, "SfmTrack2dVector");
|
||||
py::bind_vector<std::vector<gtsam::SfmTrack>>(m_, "SfmTracks");
|
||||
py::bind_vector<std::vector<gtsam::SfmCamera>>(m_, "SfmCameras");
|
||||
py::bind_vector<std::vector<std::pair<size_t, gtsam::Point2>>>(
|
||||
m_, "SfmMeasurementVector");
|
||||
|
||||
py::bind_vector<
|
||||
std::vector<gtsam::SfmTrack> >(
|
||||
m_, "SfmTracks");
|
||||
|
||||
py::bind_vector<
|
||||
std::vector<gtsam::SfmCamera> >(
|
||||
m_, "SfmCameras");
|
||||
|
||||
py::bind_vector<
|
||||
std::vector<std::pair<size_t, gtsam::Point2>>>(
|
||||
m_, "SfmMeasurementVector"
|
||||
);
|
||||
py::bind_map<gtsam::gtsfm::MatchIndicesMap>(m_, "MatchIndicesMap");
|
||||
py::bind_vector<std::vector<gtsam::gtsfm::Keypoints>>(m_, "KeypointsVector");
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from __future__ import print_function
|
|||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
import gtsam
|
||||
from gtsam import IndexPair
|
||||
from gtsam import DSFMapIndexPair, IndexPair, IndexPairSetAsArray
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
|
|
@ -29,10 +28,10 @@ class TestDSFMap(GtsamTestCase):
|
|||
def key(index_pair) -> Tuple[int, int]:
|
||||
return index_pair.i(), index_pair.j()
|
||||
|
||||
dsf = gtsam.DSFMapIndexPair()
|
||||
pair1 = gtsam.IndexPair(1, 18)
|
||||
dsf = DSFMapIndexPair()
|
||||
pair1 = IndexPair(1, 18)
|
||||
self.assertEqual(key(dsf.find(pair1)), key(pair1))
|
||||
pair2 = gtsam.IndexPair(2, 2)
|
||||
pair2 = IndexPair(2, 2)
|
||||
|
||||
# testing the merge feature of dsf
|
||||
dsf.merge(pair1, pair2)
|
||||
|
|
@ -45,7 +44,7 @@ class TestDSFMap(GtsamTestCase):
|
|||
k'th detected keypoint in image i. For the data below, merging such
|
||||
measurements into feature tracks across frames should create 2 distinct sets.
|
||||
"""
|
||||
dsf = gtsam.DSFMapIndexPair()
|
||||
dsf = DSFMapIndexPair()
|
||||
dsf.merge(IndexPair(0, 1), IndexPair(1, 2))
|
||||
dsf.merge(IndexPair(0, 1), IndexPair(3, 4))
|
||||
dsf.merge(IndexPair(4, 5), IndexPair(6, 8))
|
||||
|
|
@ -56,7 +55,7 @@ class TestDSFMap(GtsamTestCase):
|
|||
for i in sets:
|
||||
set_keys = []
|
||||
s = sets[i]
|
||||
for val in gtsam.IndexPairSetAsArray(s):
|
||||
for val in IndexPairSetAsArray(s):
|
||||
set_keys.append((val.i(), val.j()))
|
||||
merged_sets.add(tuple(set_keys))
|
||||
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
"""Unit tests for track generation using a Disjoint Set Forest data structure.
|
||||
|
||||
Authors: John Lambert
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import gtsam
|
||||
import numpy as np
|
||||
from gtsam import IndexPair, KeypointsVector, MatchIndicesMap, Point2, SfmMeasurementVector, SfmTrack2d
|
||||
from gtsam.gtsfm import Keypoints
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
class TestDsfTrackGenerator(GtsamTestCase):
|
||||
"""Tests for DsfTrackGenerator."""
|
||||
|
||||
def test_track_generation(self) -> None:
|
||||
"""Ensures that DSF generates three tracks from measurements
|
||||
in 3 images (H=200,W=400)."""
|
||||
kps_i0 = Keypoints(np.array([[10.0, 20], [30, 40]]))
|
||||
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
|
||||
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))
|
||||
|
||||
keypoints_list = KeypointsVector()
|
||||
keypoints_list.append(kps_i0)
|
||||
keypoints_list.append(kps_i1)
|
||||
keypoints_list.append(kps_i2)
|
||||
|
||||
# For each image pair (i1,i2), we provide a (K,2) matrix
|
||||
# of corresponding image indices (k1,k2).
|
||||
matches_dict = MatchIndicesMap()
|
||||
matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
|
||||
matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
|
||||
|
||||
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
|
||||
matches_dict,
|
||||
keypoints_list,
|
||||
verbose=False,
|
||||
)
|
||||
assert len(tracks) == 3
|
||||
|
||||
# Verify track 0.
|
||||
track0 = tracks[0]
|
||||
assert track0.numberMeasurements() == 2
|
||||
np.testing.assert_allclose(track0.measurements[0][1], Point2(10, 20))
|
||||
np.testing.assert_allclose(track0.measurements[1][1], Point2(50, 60))
|
||||
assert track0.measurements[0][0] == 0
|
||||
assert track0.measurements[1][0] == 1
|
||||
np.testing.assert_allclose(
|
||||
track0.measurementMatrix(),
|
||||
[
|
||||
[10, 20],
|
||||
[50, 60],
|
||||
],
|
||||
)
|
||||
np.testing.assert_allclose(track0.indexVector(), [0, 1])
|
||||
|
||||
# Verify track 1.
|
||||
track1 = tracks[1]
|
||||
np.testing.assert_allclose(
|
||||
track1.measurementMatrix(),
|
||||
[
|
||||
[30, 40],
|
||||
[70, 80],
|
||||
[130, 140],
|
||||
],
|
||||
)
|
||||
np.testing.assert_allclose(track1.indexVector(), [0, 1, 2])
|
||||
|
||||
# Verify track 2.
|
||||
track2 = tracks[2]
|
||||
np.testing.assert_allclose(
|
||||
track2.measurementMatrix(),
|
||||
[
|
||||
[90, 100],
|
||||
[110, 120],
|
||||
],
|
||||
)
|
||||
np.testing.assert_allclose(track2.indexVector(), [1, 2])
|
||||
|
||||
|
||||
class TestSfmTrack2d(GtsamTestCase):
|
||||
"""Tests for SfmTrack2d."""
|
||||
|
||||
def test_sfm_track_2d_constructor(self) -> None:
|
||||
""" """
|
||||
measurements = SfmMeasurementVector()
|
||||
measurements.append((0, Point2(10, 20)))
|
||||
track = SfmTrack2d(measurements=measurements)
|
||||
track.measurement(0)
|
||||
track.numberMeasurements() == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue