Merge pull request #1948 from borglab/hybrid-timing
commit 169523ecc6

@@ -168,6 +168,12 @@ foreach(build_type "common" ${GTSAM_CMAKE_CONFIGURATION_TYPES})
   append_config_if_not_empty(GTSAM_COMPILE_DEFINITIONS_PUBLIC ${build_type})
 endforeach()
 
+# Check if timing is enabled and add appropriate definition flag
+if(GTSAM_ENABLE_TIMING AND (NOT ${CMAKE_BUILD_TYPE} EQUAL "Timing"))
+  message(STATUS "Enabling timing for non-timing build")
+  list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE "ENABLE_TIMING")
+endif()
+
 # Linker flags:
 set(GTSAM_CMAKE_SHARED_LINKER_FLAGS_TIMING "${CMAKE_SHARED_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.")
 set(GTSAM_CMAKE_MODULE_LINKER_FLAGS_TIMING "${CMAKE_MODULE_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.")

@@ -33,6 +33,8 @@ option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Qu
 option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
 option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
 option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON)
+option(GTSAM_ENABLE_TIMING "Enable the timing tools (gttic/gttoc)" OFF)
+option(GTSAM_HYBRID_TIMING "Enable the timing of hybrid factor graph machinery" OFF)
 option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
 option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
 option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)

@@ -91,6 +91,7 @@ print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory San
 print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
 print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
 print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree")
+print_enabled_config(${GTSAM_ENABLE_TIMING} "Enable timing machinery")
 print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
 print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
 print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")

@@ -31,7 +31,9 @@
 
 namespace gtsam {
 namespace internal {
 
+using ChildOrder = FastMap<size_t, std::shared_ptr<TimingOutline>>;
+
 // a static shared_ptr to TimingOutline with nullptr as the pointer
 const static std::shared_ptr<TimingOutline> nullTimingOutline;
 
@@ -91,7 +93,6 @@ void TimingOutline::print(const std::string& outline) const {
       << n_ << " times, " << wall() << " wall, " << secs() << " children, min: "
       << min() << " max: " << max() << ")\n";
   // Order children
-  typedef FastMap<size_t, std::shared_ptr<TimingOutline> > ChildOrder;
   ChildOrder childOrder;
   for(const ChildMap::value_type& child: children_) {
     childOrder[child.second->myOrder_] = child.second;
@@ -106,6 +107,54 @@ void TimingOutline::print(const std::string& outline) const {
 #endif
 }
 
+/* ************************************************************************* */
+void TimingOutline::printCsvHeader(bool addLineBreak) const {
+#ifdef GTSAM_USE_BOOST_FEATURES
+  // Order is (CPU time, number of times, wall time, time + children in seconds,
+  // min time, max time)
+  std::cout << label_ + " cpu time (s)" << "," << label_ + " #calls" << ","
+            << label_ + " wall time(s)" << "," << label_ + " subtree time (s)"
+            << "," << label_ + " min time (s)" << "," << label_ + "max time(s)"
+            << ",";
+  // Order children
+  ChildOrder childOrder;
+  for (const ChildMap::value_type& child : children_) {
+    childOrder[child.second->myOrder_] = child.second;
+  }
+  // Print children
+  for (const ChildOrder::value_type& order_child : childOrder) {
+    order_child.second->printCsvHeader();
+  }
+  if (addLineBreak) {
+    std::cout << std::endl;
+  }
+  std::cout.flush();
+#endif
+}
+
+/* ************************************************************************* */
+void TimingOutline::printCsv(bool addLineBreak) const {
+#ifdef GTSAM_USE_BOOST_FEATURES
+  // Order is (CPU time, number of times, wall time, time + children in seconds,
+  // min time, max time)
+  std::cout << self() << "," << n_ << "," << wall() << "," << secs() << ","
+            << min() << "," << max() << ",";
+  // Order children
+  ChildOrder childOrder;
+  for (const ChildMap::value_type& child : children_) {
+    childOrder[child.second->myOrder_] = child.second;
+  }
+  // Print children
+  for (const ChildOrder::value_type& order_child : childOrder) {
+    order_child.second->printCsv(false);
+  }
+  if (addLineBreak) {
+    std::cout << std::endl;
+  }
+  std::cout.flush();
+#endif
+}
+
 void TimingOutline::print2(const std::string& outline,
     const double parentTotal) const {
 #if GTSAM_USE_BOOST_FEATURES
@@ -199,6 +199,29 @@ namespace gtsam {
 #endif
     GTSAM_EXPORT void print(const std::string& outline = "") const;
     GTSAM_EXPORT void print2(const std::string& outline = "", const double parentTotal = -1.0) const;
+
+    /**
+     * @brief Print the CSV header.
+     * Order is
+     * (CPU time, number of times, wall time, time + children in seconds, min
+     * time, max time)
+     *
+     * @param addLineBreak Flag indicating if a line break should be added at
+     * the end. Only used at the top level.
+     */
+    GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const;
+
+    /**
+     * @brief Print the times recursively from parent to child in CSV format.
+     * For each timing node, the output is
+     * (CPU time, number of times, wall time, time + children in seconds, min
+     * time, max time)
+     *
+     * @param addLineBreak Flag indicating if a line break should be added at
+     * the end. Only used at the top level.
+     */
+    GTSAM_EXPORT void printCsv(bool addLineBreak = false) const;
+
     GTSAM_EXPORT const std::shared_ptr<TimingOutline>&
       child(size_t child, const std::string& label, const std::weak_ptr<TimingOutline>& thisPtr);
     GTSAM_EXPORT void tic();
@@ -268,6 +291,14 @@ inline void tictoc_finishedIteration_() {
 inline void tictoc_print_() {
   ::gtsam::internal::gTimingRoot->print(); }
 
+// print timing in CSV format
+inline void tictoc_printCsv_(bool displayHeader = false) {
+  if (displayHeader) {
+    ::gtsam::internal::gTimingRoot->printCsvHeader(true);
+  }
+  ::gtsam::internal::gTimingRoot->printCsv(true);
+}
+
 // print mean and standard deviation
 inline void tictoc_print2_() {
   ::gtsam::internal::gTimingRoot->print2(); }

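Usage sketch (not part of the diff): the loop and the `myStep` label are made-up placeholders, and it assumes the gttic_/gttoc_ macros and the new tictoc_printCsv_ helper behave as shown in the hunks above.

#include <gtsam/base/timing.h>

int main() {
  for (int i = 0; i < 100; ++i) {
    gttic_(myStep);                      // start the "myStep" timer
    // ... the work being profiled ...
    gttoc_(myStep);                      // stop it
    gtsam::tictoc_finishedIteration_();  // close out this iteration's statistics
  }
  gtsam::tictoc_printCsv_(true);  // header row first, then one CSV row of timings
  return 0;
}
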
@@ -42,6 +42,9 @@
 // Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
 #cmakedefine GTSAM_DT_MERGING
 
+// Whether to enable timing in hybrid factor graph machinery
+#cmakedefine01 GTSAM_HYBRID_TIMING
+
 // Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
 #cmakedefine GTSAM_USE_TBB
 

@@ -57,6 +57,9 @@ namespace gtsam {
 
   AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
 
+  /// Constructor which accepts root pointer
+  AlgebraicDecisionTree(const typename Base::NodePtr root) : Base(root) {}
+
   // Explicitly non-explicit constructor
   AlgebraicDecisionTree(const Base& add) : Base(add) {}
 

@@ -77,6 +77,13 @@ DiscreteConditional::DiscreteConditional(const Signature& signature)
 /* ************************************************************************** */
 DiscreteConditional DiscreteConditional::operator*(
     const DiscreteConditional& other) const {
+  // If the root is a nullptr, we have a TableDistribution
+  // TODO(Varun) Revisit this hack after RSS2025 submission
+  if (!other.root_) {
+    DiscreteConditional dc(other.nrFrontals(), other.toDecisionTreeFactor());
+    return dc * (*this);
+  }
+
   // Take union of frontal keys
   std::set<Key> newFrontals;
   for (auto&& key : this->frontals()) newFrontals.insert(key);
@@ -479,6 +486,19 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
   return this->operator()(x.discrete());
 }
 
+/* ************************************************************************* */
+DiscreteFactor::shared_ptr DiscreteConditional::max(
+    const Ordering& keys) const {
+  return BaseFactor::max(keys);
+}
+
+/* ************************************************************************* */
+void DiscreteConditional::prune(size_t maxNrAssignments) {
+  // Get as DiscreteConditional so the probabilities are normalized
+  DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments));
+  this->root_ = pruned.root_;
+}
+
 /* ************************************************************************* */
 double DiscreteConditional::negLogConstant() const { return 0.0; }
 

@@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional
    * @param parentsValues Known values of the parents
    * @return sample from conditional
    */
-  size_t sample(const DiscreteValues& parentsValues) const;
+  virtual size_t sample(const DiscreteValues& parentsValues) const;
 
   /// Single parent version.
   size_t sample(size_t parent_value) const;
@@ -214,6 +214,15 @@ class GTSAM_EXPORT DiscreteConditional
    */
   size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
 
+  /**
+   * @brief Create new factor by maximizing over all
+   * values with the same separator.
+   *
+   * @param keys The keys to maximize over.
+   * @return DiscreteFactor::shared_ptr
+   */
+  virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
+
   /// @}
   /// @name Advanced Interface
   /// @{
@@ -267,6 +276,9 @@ class GTSAM_EXPORT DiscreteConditional
    */
   double negLogConstant() const override;
 
+  /// Prune the conditional
+  virtual void prune(size_t maxNrAssignments);
+
   /// @}
 
  protected:

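Usage sketch (not part of the diff) for the two new DiscreteConditional entry points; the conditional below is made up, and it assumes prune(n) keeps the n largest probabilities in place, mirroring DecisionTreeFactor::prune.

#include <gtsam/discrete/DiscreteConditional.h>

int main() {
  using namespace gtsam;
  DiscreteKey C(0, 4), P(1, 2);  // 4-valued child, binary parent
  DiscreteConditional conditional(C | P = "1/2/3/4 4/3/2/1");

  conditional.prune(2);  // in place: keep only the 2 most probable assignments

  Ordering overChild;
  overChild.push_back(C.first);
  DiscreteFactor::shared_ptr onParent = conditional.max(overChild);  // max out C
  onParent->print("max over child:");
  return 0;
}
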
@@ -118,17 +118,11 @@ namespace gtsam {
   // }
   // }
 
-  /**
-   * @brief Multiply all the `factors`.
-   *
-   * @param factors The factors to multiply as a DiscreteFactorGraph.
-   * @return DiscreteFactor::shared_ptr
-   */
-  static DiscreteFactor::shared_ptr DiscreteProduct(
-      const DiscreteFactorGraph& factors) {
+  /* ************************************************************************ */
+  DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const {
     // PRODUCT: multiply all factors
     gttic(product);
-    DiscreteFactor::shared_ptr product = factors.product();
+    DiscreteFactor::shared_ptr product = this->product();
     gttoc(product);
 
     // Max over all the potentials by pretending all keys are frontal:
@@ -145,7 +139,7 @@ namespace gtsam {
   std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>  //
   EliminateForMPE(const DiscreteFactorGraph& factors,
                   const Ordering& frontalKeys) {
-    DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
+    DiscreteFactor::shared_ptr product = factors.scaledProduct();
 
     // max out frontals, this is the factor on the separator
     gttic(max);
@@ -223,7 +217,7 @@ namespace gtsam {
   std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>  //
   EliminateDiscrete(const DiscreteFactorGraph& factors,
                     const Ordering& frontalKeys) {
-    DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
+    DiscreteFactor::shared_ptr product = factors.scaledProduct();
 
     // sum out frontals, this is the factor on the separator
     gttic(sum);

@@ -150,6 +150,15 @@ class GTSAM_EXPORT DiscreteFactorGraph
   /** return product of all factors as a single factor */
   DiscreteFactor::shared_ptr product() const;
 
+  /**
+   * @brief Return product of all `factors` as a single factor,
+   * which is scaled by the max value to prevent underflow
+   *
+   * @param factors The factors to multiply as a DiscreteFactorGraph.
+   * @return DiscreteFactor::shared_ptr
+   */
+  DiscreteFactor::shared_ptr scaledProduct() const;
+
   /**
    * Evaluates the factor graph given values, returns the joint probability of
    * the factor graph given specific instantiation of values

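Usage sketch (not part of the diff) contrasting product() with the new scaledProduct(); the keys and potentials are arbitrary, and the underflow behaviour follows the doc comment above (the product is rescaled by its maximum value).

#include <gtsam/discrete/DiscreteFactorGraph.h>

int main() {
  using namespace gtsam;
  DiscreteKey A(0, 2), B(1, 2);
  DiscreteFactorGraph graph;
  graph.add(A, "1e-8 1e-9");                // deliberately tiny potentials
  graph.add(A & B, "1e-7 1e-8 1e-9 1e-7");

  DiscreteFactor::shared_ptr raw = graph.product();           // values can drift toward 0
  DiscreteFactor::shared_ptr scaled = graph.scaledProduct();  // same shape, rescaled by the max
  scaled->print("scaled product:");
  return 0;
}
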
@@ -0,0 +1,174 @@
+/* ----------------------------------------------------------------------------
+
+ * GTSAM Copyright 2010, Georgia Tech Research Corporation,
+ * Atlanta, Georgia 30332-0415
+ * All Rights Reserved
+ * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
+
+ * See LICENSE for the license information
+
+ * -------------------------------------------------------------------------- */
+
+/**
+ * @file TableDistribution.cpp
+ * @date Dec 22, 2024
+ * @author Varun Agrawal
+ */
+
+#include <gtsam/base/Testable.h>
+#include <gtsam/base/debug.h>
+#include <gtsam/discrete/Ring.h>
+#include <gtsam/discrete/Signature.h>
+#include <gtsam/discrete/TableDistribution.h>
+#include <gtsam/hybrid/HybridValues.h>
+
+#include <algorithm>
+#include <cassert>
+#include <random>
+#include <set>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+using namespace std;
+using std::pair;
+using std::stringstream;
+using std::vector;
+
+namespace gtsam {
+
+/// Normalize sparse_table
+static Eigen::SparseVector<double> normalizeSparseTable(
+    const Eigen::SparseVector<double>& sparse_table) {
+  return sparse_table / sparse_table.sum();
+}
+
+/* ************************************************************************** */
+TableDistribution::TableDistribution(const TableFactor& f)
+    : BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)),
+      table_(f / (*std::dynamic_pointer_cast<TableFactor>(
+                 f.sum(f.keys().size())))) {}
+
+/* ************************************************************************** */
+TableDistribution::TableDistribution(const DiscreteKeys& keys,
+                                     const std::vector<double>& potentials)
+    : BaseConditional(keys.size(), keys, ADT(nullptr)),
+      table_(TableFactor(
+          keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
+}
+
+/* ************************************************************************** */
+TableDistribution::TableDistribution(const DiscreteKeys& keys,
+                                     const std::string& potentials)
+    : BaseConditional(keys.size(), keys, ADT(nullptr)),
+      table_(TableFactor(
+          keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
+}
+
+/* ************************************************************************** */
+void TableDistribution::print(const string& s,
+                              const KeyFormatter& formatter) const {
+  cout << s << " P( ";
+  for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
+    cout << formatter(*it) << " ";
+  }
+  cout << "):\n";
+  table_.print("", formatter);
+  cout << endl;
+}
+
+/* ************************************************************************** */
+bool TableDistribution::equals(const DiscreteFactor& other, double tol) const {
+  auto dtc = dynamic_cast<const TableDistribution*>(&other);
+  if (!dtc) {
+    return false;
+  } else {
+    const DiscreteConditional& f(
+        static_cast<const DiscreteConditional&>(other));
+    return table_.equals(dtc->table_, tol) &&
+           DiscreteConditional::BaseConditional::equals(f, tol);
+  }
+}
+
+/* ****************************************************************************/
+DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const {
+  return table_.sum(nrFrontals);
+}
+
+/* ****************************************************************************/
+DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const {
+  return table_.sum(keys);
+}
+
+/* ****************************************************************************/
+DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const {
+  return table_.max(nrFrontals);
+}
+
+/* ****************************************************************************/
+DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const {
+  return table_.max(keys);
+}
+
+/* ****************************************************************************/
+DiscreteFactor::shared_ptr TableDistribution::operator/(
+    const DiscreteFactor::shared_ptr& f) const {
+  return table_ / f;
+}
+
+/* ************************************************************************ */
+DiscreteValues TableDistribution::argmax() const {
+  uint64_t maxIdx = 0;
+  double maxValue = 0.0;
+
+  Eigen::SparseVector<double> sparseTable = table_.sparseTable();
+
+  for (SparseIt it(sparseTable); it; ++it) {
+    if (it.value() > maxValue) {
+      maxIdx = it.index();
+      maxValue = it.value();
+    }
+  }
+
+  return table_.findAssignments(maxIdx);
+}
+
+/* ****************************************************************************/
+void TableDistribution::prune(size_t maxNrAssignments) {
+  table_ = table_.prune(maxNrAssignments);
+}
+
+/* ****************************************************************************/
+size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
+  static mt19937 rng(2);  // random number generator
+
+  DiscreteKeys parentsKeys;
+  for (auto&& [key, _] : parentsValues) {
+    parentsKeys.push_back({key, table_.cardinality(key)});
+  }
+
+  // Get the correct conditional distribution: P(F|S=parentsValues)
+  TableFactor pFS = table_.choose(parentsValues, parentsKeys);
+
+  // TODO(Duy): only works for one key now, seems horribly slow this way
+  if (nrFrontals() != 1) {
+    throw std::invalid_argument(
+        "TableDistribution::sample can only be called on single variable "
+        "conditionals");
+  }
+  Key key = firstFrontalKey();
+  size_t nj = cardinality(key);
+  vector<double> p(nj);
+  DiscreteValues frontals;
+  for (size_t value = 0; value < nj; value++) {
+    frontals[key] = value;
+    p[value] = pFS(frontals);  // P(F=value|S=parentsValues)
+    if (p[value] == 1.0) {
+      return value;  // shortcut exit
+    }
+  }
+  std::discrete_distribution<size_t> distribution(p.begin(), p.end());
+  return distribution(rng);
+}
+
+}  // namespace gtsam
@ -0,0 +1,177 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file TableDistribution.h
|
||||||
|
* @date Dec 22, 2024
|
||||||
|
* @author Varun Agrawal
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Distribution which uses a SparseVector as the internal
|
||||||
|
* representation, similar to the TableFactor.
|
||||||
|
*
|
||||||
|
* This is primarily used in the case when we have a clique in the BayesTree
|
||||||
|
* which consists of all the discrete variables, e.g. in hybrid elimination.
|
||||||
|
*
|
||||||
|
* @ingroup discrete
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
||||||
|
private:
|
||||||
|
TableFactor table_;
|
||||||
|
|
||||||
|
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// typedefs needed to play nice with gtsam
|
||||||
|
typedef TableDistribution This; ///< Typedef to this class
|
||||||
|
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||||
|
typedef DiscreteConditional
|
||||||
|
BaseConditional; ///< Typedef to our conditional base class
|
||||||
|
|
||||||
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Default constructor needed for serialization.
|
||||||
|
TableDistribution() {}
|
||||||
|
|
||||||
|
/// Construct from TableFactor.
|
||||||
|
TableDistribution(const TableFactor& f);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from DiscreteKeys and std::vector.
|
||||||
|
*/
|
||||||
|
TableDistribution(const DiscreteKeys& keys,
|
||||||
|
const std::vector<double>& potentials);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from single DiscreteKey and std::vector.
|
||||||
|
*/
|
||||||
|
TableDistribution(const DiscreteKey& key,
|
||||||
|
const std::vector<double>& potentials)
|
||||||
|
: TableDistribution(DiscreteKeys(key), potentials) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from DiscreteKey and std::string.
|
||||||
|
*/
|
||||||
|
TableDistribution(const DiscreteKeys& keys, const std::string& potentials);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from single DiscreteKey and std::string.
|
||||||
|
*/
|
||||||
|
TableDistribution(const DiscreteKey& key, const std::string& potentials)
|
||||||
|
: TableDistribution(DiscreteKeys(key), potentials) {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// GTSAM-style print
|
||||||
|
void print(
|
||||||
|
const std::string& s = "Table Distribution: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/// GTSAM-style equals
|
||||||
|
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Return the underlying TableFactor
|
||||||
|
TableFactor table() const { return table_; }
|
||||||
|
|
||||||
|
using BaseConditional::evaluate; // HybridValues version
|
||||||
|
|
||||||
|
/// Evaluate the conditional given the values.
|
||||||
|
virtual double evaluate(const Assignment<Key>& values) const override {
|
||||||
|
return table_.evaluate(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create new factor by summing all values with the same separator values
|
||||||
|
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override;
|
||||||
|
|
||||||
|
/// Create new factor by summing all values with the same separator values
|
||||||
|
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
|
||||||
|
|
||||||
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
|
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
|
||||||
|
|
||||||
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
|
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
||||||
|
|
||||||
|
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||||
|
DiscreteFactor::shared_ptr operator/(
|
||||||
|
const DiscreteFactor::shared_ptr& f) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return assignment that maximizes value.
|
||||||
|
*
|
||||||
|
* @return maximizing assignment for the variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues argmax() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* sample
|
||||||
|
* @param parentsValues Known values of the parents
|
||||||
|
* @return sample from conditional
|
||||||
|
*/
|
||||||
|
virtual size_t sample(const DiscreteValues& parentsValues) const override;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Advanced Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Prune the conditional
|
||||||
|
virtual void prune(size_t maxNrAssignments) override;
|
||||||
|
|
||||||
|
/// Get a DecisionTreeFactor representation.
|
||||||
|
DecisionTreeFactor toDecisionTreeFactor() const override {
|
||||||
|
return table_.toDecisionTreeFactor();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the number of non-zero values.
|
||||||
|
uint64_t nrValues() const override { return table_.sparseTable().nonZeros(); }
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class Archive>
|
||||||
|
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(table_);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
// TableDistribution
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<TableDistribution> : public Testable<TableDistribution> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
|
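Usage sketch (not part of the diff) for the new class, using the DiscreteKey/string constructor declared above; the key and probabilities are arbitrary.

#include <gtsam/discrete/TableDistribution.h>

#include <iostream>

int main() {
  using namespace gtsam;
  DiscreteKey D(0, 3);                        // one ternary variable
  TableDistribution prior(D, "0.2 0.5 0.3");  // stored as a normalized sparse table

  DiscreteValues v;
  v[D.first] = 1;
  std::cout << "P(D=1) = " << prior.evaluate(v) << "\n";           // 0.5 with these numbers
  std::cout << "argmax = " << prior.argmax().at(D.first) << "\n";  // 1
  return 0;
}
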
@@ -87,6 +87,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
     return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
   }
 
+ public:
   /**
    * Convert probability table given as doubles to SparseVector.
    * Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
@@ -98,7 +99,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
   static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
                                              const std::string& table);
 
- public:
   // typedefs needed to play nice with gtsam
   typedef TableFactor This;
   typedef DiscreteFactor Base;  ///< Typedef to base class
@@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
   DecisionTreeFactor toDecisionTreeFactor() const override;
 
   /// Create a TableFactor that is a subset of this TableFactor
-  TableFactor choose(const DiscreteValues assignments,
+  TableFactor choose(const DiscreteValues parentAssignments,
                      DiscreteKeys parent_keys) const;
 
   /// Create new factor by summing all values with the same separator values

@@ -168,6 +168,43 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
   std::vector<double> pmf() const;
 };
 
+#include <gtsam/discrete/TableFactor.h>
+virtual class TableFactor : gtsam::DiscreteFactor {
+  TableFactor();
+  TableFactor(const gtsam::DiscreteKeys& keys,
+              const gtsam::TableFactor& potentials);
+  TableFactor(const gtsam::DiscreteKeys& keys, std::vector<double>& table);
+  TableFactor(const gtsam::DiscreteKeys& keys, string spec);
+  TableFactor(const gtsam::DiscreteKeys& keys,
+              const gtsam::DecisionTreeFactor& dtf);
+  TableFactor(const gtsam::DecisionTreeFactor& dtf);
+
+  void print(string s = "TableFactor\n",
+             const gtsam::KeyFormatter& keyFormatter =
+                 gtsam::DefaultKeyFormatter) const;
+
+  double evaluate(const gtsam::DiscreteValues& values) const;
+  double error(const gtsam::DiscreteValues& values) const;
+};
+
+#include <gtsam/discrete/TableDistribution.h>
+virtual class TableDistribution : gtsam::DiscreteConditional {
+  TableDistribution();
+  TableDistribution(const gtsam::TableFactor& f);
+  TableDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
+  TableDistribution(const gtsam::DiscreteKeys& keys, std::vector<double> spec);
+  TableDistribution(const gtsam::DiscreteKeys& keys, string spec);
+  TableDistribution(const gtsam::DiscreteKey& key, string spec);
+
+  void print(string s = "Table Distribution\n",
+             const gtsam::KeyFormatter& keyFormatter =
+                 gtsam::DefaultKeyFormatter) const;
+
+  gtsam::TableFactor table() const;
+  double evaluate(const gtsam::DiscreteValues& values) const;
+  size_t nrValues() const;
+};
+
 #include <gtsam/discrete/DiscreteBayesNet.h>
 class DiscreteBayesNet {
   DiscreteBayesNet();

@@ -19,6 +19,7 @@
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridValues.h>
 
@@ -55,12 +56,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
     joint = joint * (*conditional);
   }
 
-  // Prune the joint. NOTE: again, possibly quite expensive.
-  const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);
-
-  // Create a the result starting with the pruned joint.
+  // Create the result starting with the pruned joint.
   HybridBayesNet result;
-  result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);
+  result.emplace_shared<DiscreteConditional>(joint);
+  // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
+  result.back()->asDiscrete()->prune(maxNrLeaves);
+
+  // Get pruned discrete probabilities so
+  // we can prune HybridGaussianConditionals.
+  DiscreteConditional pruned = *result.back()->asDiscrete();
 
   /* To prune, we visitWith every leaf in the HybridGaussianConditional.
    * For each leaf, using the assignment we can check the discrete decision tree
@@ -126,7 +130,14 @@ HybridValues HybridBayesNet::optimize() const {
 
   for (auto &&conditional : *this) {
     if (conditional->isDiscrete()) {
-      discrete_fg.push_back(conditional->asDiscrete());
+      if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
+        // The number of keys should be small so should not
+        // be expensive to convert to DiscreteConditional.
+        discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
+                                                  dtc->toDecisionTreeFactor()));
+      } else {
+        discrete_fg.push_back(conditional->asDiscrete());
+      }
     }
   }
 

@@ -20,6 +20,7 @@
 #include <gtsam/base/treeTraversal-inst.h>
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridBayesTree.h>
 #include <gtsam/hybrid/HybridConditional.h>
@@ -41,6 +42,22 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
   return Base::equals(other, tol);
 }
 
+/* ************************************************************************* */
+DiscreteValues HybridBayesTree::discreteMaxProduct(
+    const DiscreteFactorGraph& dfg) const {
+  DiscreteFactor::shared_ptr product = dfg.scaledProduct();
+
+  // Check type of product, and get as TableFactor for efficiency.
+  TableFactor p;
+  if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
+    p = *tf;
+  } else {
+    p = TableFactor(product->toDecisionTreeFactor());
+  }
+  DiscreteValues assignment = TableDistribution(p).argmax();
+  return assignment;
+}
+
 /* ************************************************************************* */
 HybridValues HybridBayesTree::optimize() const {
   DiscreteFactorGraph discrete_fg;
@@ -52,8 +69,9 @@ HybridValues HybridBayesTree::optimize() const {
 
   // The root should be discrete only, we compute the MPE
   if (root_conditional->isDiscrete()) {
-    discrete_fg.push_back(root_conditional->asDiscrete());
-    mpe = discrete_fg.optimize();
+    auto discrete = root_conditional->asDiscrete<TableDistribution>();
+    discrete_fg.push_back(discrete);
+    mpe = discreteMaxProduct(discrete_fg);
   } else {
     throw std::runtime_error(
         "HybridBayesTree root is not discrete-only. Please check elimination "
@@ -179,16 +197,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
 
 /* ************************************************************************* */
 void HybridBayesTree::prune(const size_t maxNrLeaves) {
-  auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
+  auto prunedDiscreteProbs =
+      this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
 
-  DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
-  discreteProbs->root_ = prunedDiscreteProbs.root_;
+  // Imperative pruning
+  prunedDiscreteProbs->prune(maxNrLeaves);
 
   /// Helper struct for pruning the hybrid bayes tree.
   struct HybridPrunerData {
     /// The discrete decision tree after pruning.
-    DecisionTreeFactor prunedDiscreteProbs;
-    HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
+    DiscreteConditional::shared_ptr prunedDiscreteProbs;
+    HybridPrunerData(const DiscreteConditional::shared_ptr& prunedDiscreteProbs,
                      const HybridBayesTree::sharedNode& parentClique)
         : prunedDiscreteProbs(prunedDiscreteProbs) {}
 
@@ -213,7 +232,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
       if (!hybridGaussianCond->pruned()) {
         // Imperative
         clique->conditional() = std::make_shared<HybridConditional>(
-            hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
+            hybridGaussianCond->prune(*parentData.prunedDiscreteProbs));
       }
     }
     return parentData;
@@ -115,6 +115,10 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
   /// @}
 
  private:
+  /// Helper method to compute the max product assignment
+  /// given a DiscreteFactorGraph
+  DiscreteValues discreteMaxProduct(const DiscreteFactorGraph& dfg) const;
+
 #if GTSAM_ENABLE_BOOST_SERIALIZATION
   /** Serialization function */
   friend class boost::serialization::access;
@@ -166,12 +166,13 @@ class GTSAM_EXPORT HybridConditional
   }
 
   /**
-   * @brief Return conditional as a DiscreteConditional
+   * @brief Return conditional as a DiscreteConditional or specified type T.
    * @return nullptr if not a DiscreteConditional
   * @return DiscreteConditional::shared_ptr
    */
-  DiscreteConditional::shared_ptr asDiscrete() const {
-    return std::dynamic_pointer_cast<DiscreteConditional>(inner_);
+  template <typename T = DiscreteConditional>
+  typename T::shared_ptr asDiscrete() const {
+    return std::dynamic_pointer_cast<T>(inner_);
   }
 
   /// Get the type-erased pointer to the inner type

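Usage sketch (not part of the diff) of the now-templated accessor; the `inspect` helper and its argument are hypothetical.

#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridConditional.h>

void inspect(const gtsam::HybridConditional& hc) {
  // Default template argument: identical to the old asDiscrete().
  gtsam::DiscreteConditional::shared_ptr dc = hc.asDiscrete();

  // Ask for the more specific type; returns nullptr if the inner
  // conditional is not actually a TableDistribution.
  if (auto td = hc.asDiscrete<gtsam::TableDistribution>()) {
    td->print("table distribution:");
  } else if (dc) {
    dc->print("generic discrete conditional:");
  }
}
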
@@ -304,7 +304,7 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
 
 /* *******************************************************************************/
 HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
-    const DecisionTreeFactor &discreteProbs) const {
+    const DiscreteConditional &discreteProbs) const {
   // Find keys in discreteProbs.keys() but not in this->keys():
   std::set<Key> mine(this->keys().begin(), this->keys().end());
   std::set<Key> theirs(discreteProbs.keys().begin(),
@@ -23,6 +23,7 @@
 #include <gtsam/discrete/DecisionTree-inl.h>
 #include <gtsam/discrete/DecisionTree.h>
 #include <gtsam/discrete/DecisionTreeFactor.h>
+#include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/DiscreteKey.h>
 #include <gtsam/hybrid/HybridFactor.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>
@@ -235,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional
    * @return Shared pointer to possibly a pruned HybridGaussianConditional
    */
  HybridGaussianConditional::shared_ptr prune(
-      const DecisionTreeFactor &discreteProbs) const;
+      const DiscreteConditional &discreteProbs) const;
 
   /// Return true if the conditional has already been pruned.
   bool pruned() const { return pruned_; }
@@ -20,12 +20,13 @@
 
 #include <gtsam/base/utilities.h>
 #include <gtsam/discrete/Assignment.h>
-#include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteEliminationTree.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
 #include <gtsam/discrete/DiscreteJunctionTree.h>
 #include <gtsam/discrete/DiscreteKey.h>
 #include <gtsam/discrete/DiscreteValues.h>
+#include <gtsam/discrete/TableDistribution.h>
+#include <gtsam/discrete/TableFactor.h>
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridEliminationTree.h>
 #include <gtsam/hybrid/HybridFactor.h>
@@ -241,29 +242,29 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
 /* ************************************************************************ */
 /**
  * @brief Take negative log-values, shift them so that the minimum value is 0,
- * and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
+ * and then exponentiate to create a TableFactor (not normalized yet!).
  *
  * @param errors DecisionTree of (unnormalized) errors.
- * @return DecisionTreeFactor::shared_ptr
+ * @return TableFactor::shared_ptr
  */
-static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
+static TableFactor::shared_ptr DiscreteFactorFromErrors(
     const DiscreteKeys &discreteKeys,
     const AlgebraicDecisionTree<Key> &errors) {
   double min_log = errors.min();
   AlgebraicDecisionTree<Key> potentials(
       errors, [&min_log](const double x) { return exp(-(x - min_log)); });
-  return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
+  return std::make_shared<TableFactor>(discreteKeys, potentials);
 }
 
 /* ************************************************************************ */
-static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
-discreteElimination(const HybridGaussianFactorGraph &factors,
-                    const Ordering &frontalKeys) {
+static DiscreteFactorGraph CollectDiscreteFactors(
+    const HybridGaussianFactorGraph &factors) {
   DiscreteFactorGraph dfg;
 
   for (auto &f : factors) {
     if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
       dfg.push_back(df);
 
     } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
       // Case where we have a HybridGaussianFactor with no continuous keys.
       // In this case, compute a discrete factor from the remaining error.

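In other words, since the errors E(m) here are negative log potentials, DiscreteFactorFromErrors computes

    potentials(m) = exp(-(E(m) - min_m' E(m')))

so the largest entry is exactly 1 regardless of how large the raw errors are, which keeps the table away from floating-point underflow; as the comment notes, the result is deliberately not normalized yet.
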
@@ -282,16 +283,73 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
     } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
       auto dc = hc->asDiscrete();
       if (!dc) throwRuntimeError("discreteElimination", dc);
-      dfg.push_back(dc);
+#if GTSAM_HYBRID_TIMING
+      gttic_(ConvertConditionalToTableFactor);
+#endif
+      if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
+        /// Get the underlying TableFactor
+        dfg.push_back(dtc->table());
+      } else {
+        // Convert DiscreteConditional to TableFactor
+        auto tdc = std::make_shared<TableFactor>(*dc);
+        dfg.push_back(tdc);
+      }
+#if GTSAM_HYBRID_TIMING
+      gttoc_(ConvertConditionalToTableFactor);
+#endif
     } else {
       throwRuntimeError("discreteElimination", f);
     }
   }
 
-  // NOTE: This does sum-product. For max-product, use EliminateForMPE.
-  auto result = EliminateDiscrete(dfg, frontalKeys);
-
-  return {std::make_shared<HybridConditional>(result.first), result.second};
+  return dfg;
+}
+
+/* ************************************************************************ */
+static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
+discreteElimination(const HybridGaussianFactorGraph &factors,
+                    const Ordering &frontalKeys) {
+  DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
+
+#if GTSAM_HYBRID_TIMING
+  gttic_(EliminateDiscrete);
+#endif
+  // Check if separator is empty.
+  // This is the same as checking if the number of frontal variables
+  // is the same as the number of variables in the DiscreteFactorGraph.
+  // If the separator is empty, we have a clique of all the discrete variables
+  // so we can use the TableFactor for efficiency.
+  if (frontalKeys.size() == dfg.keys().size()) {
+    // Get product factor
+    DiscreteFactor::shared_ptr product = dfg.scaledProduct();
+
+#if GTSAM_HYBRID_TIMING
+    gttic_(EliminateDiscreteFormDiscreteConditional);
+#endif
+    // Check type of product, and get as TableFactor for efficiency.
+    TableFactor p;
+    if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
+      p = *tf;
+    } else {
+      p = TableFactor(product->toDecisionTreeFactor());
+    }
+    auto conditional = std::make_shared<TableDistribution>(p);
+#if GTSAM_HYBRID_TIMING
+    gttoc_(EliminateDiscreteFormDiscreteConditional);
+#endif
+
+    DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
+
+    return {std::make_shared<HybridConditional>(conditional), sum};
+
+  } else {
+    // Perform sum-product.
+    auto result = EliminateDiscrete(dfg, frontalKeys);
+    return {std::make_shared<HybridConditional>(result.first), result.second};
+  }
+#if GTSAM_HYBRID_TIMING
+  gttoc_(EliminateDiscrete);
+#endif
 }
 
 /* ************************************************************************ */
@@ -319,8 +377,19 @@ static std::shared_ptr<Factor> createDiscreteFactor(
     }
   };
 
+#if GTSAM_HYBRID_TIMING
+  gttic_(DiscreteBoundaryErrors);
+#endif
   AlgebraicDecisionTree<Key> errors(eliminationResults, calculateError);
-  return DiscreteFactorFromErrors(discreteSeparator, errors);
+#if GTSAM_HYBRID_TIMING
+  gttoc_(DiscreteBoundaryErrors);
+  gttic_(DiscreteBoundaryResult);
+#endif
+  auto result = DiscreteFactorFromErrors(discreteSeparator, errors);
+#if GTSAM_HYBRID_TIMING
+  gttoc_(DiscreteBoundaryResult);
+#endif
+  return result;
 }
 
 /* *******************************************************************************/
@ -360,12 +429,18 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
// the discrete separator will be *all* the discrete keys.
|
// the discrete separator will be *all* the discrete keys.
|
||||||
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
|
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(HybridCollectProductFactor);
|
||||||
|
#endif
|
||||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||||
// decision tree indexed by all discrete keys involved. Just like any hybrid
|
// decision tree indexed by all discrete keys involved. Just like any hybrid
|
||||||
// factor, every assignment also has a scalar error, in this case the sum of
|
// factor, every assignment also has a scalar error, in this case the sum of
|
||||||
// all errors in the graph. This error is assignment-specific and accounts for
|
// all errors in the graph. This error is assignment-specific and accounts for
|
||||||
// any difference in noise models used.
|
// any difference in noise models used.
|
||||||
HybridGaussianProductFactor productFactor = collectProductFactor();
|
HybridGaussianProductFactor productFactor = collectProductFactor();
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(HybridCollectProductFactor);
|
||||||
|
#endif
|
||||||
|
|
||||||
// Check if a factor is null
|
// Check if a factor is null
|
||||||
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
|
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
|
||||||
|
@ -393,8 +468,14 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
return {conditional, conditional->negLogConstant(), factor, scalar};
|
return {conditional, conditional->negLogConstant(), factor, scalar};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(HybridEliminate);
|
||||||
|
#endif
|
||||||
// Perform elimination!
|
// Perform elimination!
|
||||||
const ResultTree eliminationResults(productFactor, eliminate);
|
const ResultTree eliminationResults(productFactor, eliminate);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(HybridEliminate);
|
||||||
|
#endif
|
||||||
|
|
||||||
// If there are no more continuous parents we create a DiscreteFactor with the
|
// If there are no more continuous parents we create a DiscreteFactor with the
|
||||||
// error for each discrete choice. Otherwise, create a HybridGaussianFactor
|
// error for each discrete choice. Otherwise, create a HybridGaussianFactor
|
||||||
|
|
|
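Note on reading the timers back: the GTSAM_HYBRID_TIMING blocks added above only record labels (HybridCollectProductFactor, HybridEliminate, EliminateDiscrete, DiscreteBoundaryErrors, ...) into GTSAM's gttic_/gttoc_ registry; they do not print anything themselves. A minimal sketch of how the accumulated timings could be dumped after one elimination pass, assuming the usual tictoc helper macros from gtsam/base/timing.h and a hypothetical, already-populated HybridGaussianFactorGraph named graph; this helper is illustrative and not part of the change:

    #include <gtsam/base/timing.h>
    #include <gtsam/hybrid/HybridGaussianFactorGraph.h>

    // Hypothetical profiling helper: run hybrid elimination once, then print
    // the timing tree that the GTSAM_HYBRID_TIMING blocks above populated.
    void profileHybridElimination(const gtsam::HybridGaussianFactorGraph& graph) {
      const auto bayesNet = graph.eliminateSequential();  // hits gttic_/gttoc_
      tictoc_finishedIteration_();  // close out the current timing iteration
      tictoc_print_();              // dump labels such as HybridEliminate
    }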
@ -104,7 +104,13 @@ void HybridGaussianISAM::updateInternal(
       elimination_ordering, function, std::cref(index));

   if (maxNrLeaves) {
+#if GTSAM_HYBRID_TIMING
+    gttic_(HybridBayesTreePrune);
+#endif
     bayesTree->prune(*maxNrLeaves);
+#if GTSAM_HYBRID_TIMING
+    gttoc_(HybridBayesTreePrune);
+#endif
   }

   // Re-add into Bayes tree data structures

@ -15,6 +15,7 @@
  * @author Varun Agrawal
  */

+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
 #include <gtsam/hybrid/HybridNonlinearFactor.h>
 #include <gtsam/hybrid/HybridNonlinearISAM.h>

@ -65,7 +66,14 @@ void HybridNonlinearISAM::reorderRelinearize() {
   // Obtain the new linearization point
   const Values newLinPoint = estimate();

-  auto discreteProbs = *(isam_.roots().at(0)->conditional()->asDiscrete());
+  DiscreteConditional::shared_ptr discreteProbabilities;
+
+  auto discreteRoot = isam_.roots().at(0)->conditional();
+  if (discreteRoot->asDiscrete<TableDistribution>()) {
+    discreteProbabilities = discreteRoot->asDiscrete<TableDistribution>();
+  } else {
+    discreteProbabilities = discreteRoot->asDiscrete();
+  }

   isam_.clear();

@ -73,7 +81,7 @@ void HybridNonlinearISAM::reorderRelinearize() {
   HybridNonlinearFactorGraph pruned_factors;
   for (auto&& factor : factors_) {
     if (auto nf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
-      pruned_factors.push_back(nf->prune(discreteProbs));
+      pruned_factors.push_back(nf->prune(*discreteProbabilities));
     } else {
       pruned_factors.push_back(factor);
     }

@ -20,6 +20,7 @@
 #include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/DiscreteKey.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridGaussianFactorGraph.h>

@ -79,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
   double midway = mu1 - mu0;
   auto eliminationResult =
       gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
-  auto pMid = *eliminationResult->at(0)->asDiscrete();
-  EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid));
+  auto pMid = eliminationResult->at(0)->asDiscrete<TableDistribution>();
+  EXPECT(assert_equal(TableDistribution(m, "60 40"), *pMid));

   // Everywhere else, the result should be a sigmoid.
   for (const double shift : {-4, -2, 0, 2, 4}) {

@ -90,7 +91,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
     // Workflow 1: convert HBN to HFG and solve
     auto eliminationResult1 =
         gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
-    auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
+    auto posterior1 =
+        *eliminationResult1->at(0)->asDiscrete<TableDistribution>();
     EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);

     // Workflow 2: directly specify HFG and solve

@ -99,7 +101,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
         m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)});
     hfg1.push_back(mixing);
     auto eliminationResult2 = hfg1.eliminateSequential();
-    auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
+    auto posterior2 =
+        *eliminationResult2->at(0)->asDiscrete<TableDistribution>();
     EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
   }
 }

@ -133,13 +136,13 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
   // Eliminate the graph!
   auto eliminationResultMax = gfg.eliminateSequential();

-  // Equality of posteriors asserts that the elimination is correct (same ratios
-  // for all modes)
+  // Equality of posteriors asserts that the elimination is correct
+  // (same ratios for all modes)
   EXPECT(assert_equal(expectedDiscretePosterior,
                       eliminationResultMax->discretePosterior(vv)));

-  auto pMax = *eliminationResultMax->at(0)->asDiscrete();
-  EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));
+  auto pMax = *eliminationResultMax->at(0)->asDiscrete<TableDistribution>();
+  EXPECT(assert_equal(TableDistribution(m, "42 58"), pMax, 1e-4));

   // Everywhere else, the result should be a bell curve like function.
   for (const double shift : {-4, -2, 0, 2, 4}) {

@ -149,7 +152,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
     // Workflow 1: convert HBN to HFG and solve
     auto eliminationResult1 =
         gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
-    auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
+    auto posterior1 =
+        *eliminationResult1->at(0)->asDiscrete<TableDistribution>();
     EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);

     // Workflow 2: directly specify HFG and solve

@ -158,10 +162,12 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
         m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)});
     hfg.push_back(mixing);
     auto eliminationResult2 = hfg.eliminateSequential();
-    auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
+    auto posterior2 =
+        *eliminationResult2->at(0)->asDiscrete<TableDistribution>();
     EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
   }
 }

 /* ************************************************************************* */
 int main() {
   TestResult tr;
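The test edits above repeatedly swap a DiscreteConditional built from a '/'-separated ratio spec for a TableDistribution built from space-separated weights. A small side-by-side sketch of the two constructions, under the assumption (as in these tests) that both normalize the weights into a prior over a binary key; the key and the numbers are illustrative only:

    #include <gtsam/discrete/DiscreteConditional.h>
    #include <gtsam/discrete/TableDistribution.h>

    using namespace gtsam;

    void specStringSketch() {
      const DiscreteKey m(0, 2);  // illustrative binary mode key

      // Older style: ratio spec parsed via a Signature ("60/40").
      const DiscreteConditional oldPrior(m % "60/40");

      // Style used in this PR: space-separated weights ("60 40").
      const TableDistribution newPrior(m, "60 40");

      // Both evaluate on a DiscreteValues assignment, e.g. P(m = 0).
      const DiscreteValues assignment{{m.first, 0}};
      const double p0 = newPrior(assignment);  // ~0.6 once normalized
    }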
@ -20,6 +20,7 @@

 #include <gtsam/base/Testable.h>
 #include <gtsam/discrete/DiscreteFactor.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridBayesTree.h>
 #include <gtsam/hybrid/HybridConditional.h>

@ -454,7 +455,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
   }

   size_t maxNrLeaves = 3;
-  auto prunedDecisionTree = joint.prune(maxNrLeaves);
+  DiscreteConditional prunedDecisionTree(joint);
+  prunedDecisionTree.prune(maxNrLeaves);

 #ifdef GTSAM_DT_MERGING
   EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,

@ -16,6 +16,7 @@
  */

 #include <gtsam/discrete/DiscreteBayesNet.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/geometry/Pose2.h>
 #include <gtsam/geometry/Pose3.h>
 #include <gtsam/hybrid/HybridBayesNet.h>

@ -464,14 +465,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) {

   // Create expected discrete conditional on m0.
   DiscreteKey m(M(0), 2);
-  DiscreteConditional expected(m % "0.51341712/1");  // regression
+  TableDistribution expected(m, "0.51341712 1");  // regression

   // Eliminate into BN using one ordering
   const Ordering ordering1{X(0), X(1), M(0)};
   HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);

   // Check that the discrete conditional matches the expected.
-  auto dc1 = bn1->back()->asDiscrete();
+  auto dc1 = bn1->back()->asDiscrete<TableDistribution>();
   EXPECT(assert_equal(expected, *dc1, 1e-9));

   // Eliminate into BN using a different ordering

@ -479,7 +480,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
   HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);

   // Check that the discrete conditional matches the expected.
-  auto dc2 = bn2->back()->asDiscrete();
+  auto dc2 = bn2->back()->asDiscrete<TableDistribution>();
   EXPECT(assert_equal(expected, *dc2, 1e-9));
 }

@ -261,7 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
     potentials[i] = 1;
     const DecisionTreeFactor decisionTreeFactor(keys, potentials);
     // Prune the HybridGaussianConditional
-    const auto pruned = hgc.prune(decisionTreeFactor);
+    const auto pruned =
+        hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
     // Check that the pruned HybridGaussianConditional has 1 conditional
     EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
   }

@ -271,7 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
                                     0, 0, 0.5, 0};
     const DecisionTreeFactor decisionTreeFactor(keys, potentials);

-    const auto pruned = hgc.prune(decisionTreeFactor);
+    const auto pruned =
+        hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));

     // Check that the pruned HybridGaussianConditional has 2 conditionals
     EXPECT_LONGS_EQUAL(2, pruned->nrComponents());

@ -286,7 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
                                     0, 0, 0.5, 0};
     const DecisionTreeFactor decisionTreeFactor(keys, potentials);

-    const auto pruned = hgc.prune(decisionTreeFactor);
+    const auto pruned =
+        hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));

     // Check that the pruned HybridGaussianConditional has 3 conditionals
     EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
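The three Prune cases above track an interface change: HybridGaussianConditional::prune is now fed a DiscreteConditional (wrapped around the kept potentials) instead of a bare DecisionTreeFactor. A minimal call-site sketch of that migration, assuming an existing conditional hgc and its DiscreteKeys as in the tests; the helper name is hypothetical:

    #include <gtsam/discrete/DecisionTreeFactor.h>
    #include <gtsam/discrete/DiscreteConditional.h>
    #include <gtsam/hybrid/HybridGaussianConditional.h>

    using namespace gtsam;

    // Hypothetical wrapper: prune a hybrid Gaussian conditional with a
    // DecisionTreeFactor of kept potentials, via the new
    // prune(DiscreteConditional) signature exercised in the tests above.
    auto pruneWithPotentials(const HybridGaussianConditional& hgc,
                             const DiscreteKeys& keys,
                             const DecisionTreeFactor& kept) {
      const DiscreteConditional keptAsConditional(keys.size(), kept);
      return hgc.prune(keptAsConditional);
    }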
@ -25,6 +25,7 @@
 #include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteKey.h>
 #include <gtsam/discrete/DiscreteValues.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridFactor.h>

@ -114,10 +115,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
   EXPECT(HybridConditional::CheckInvariants(*result.first, values));

   // Check that factor is discrete and correct
-  auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
+  auto factor = std::dynamic_pointer_cast<TableFactor>(result.second);
   CHECK(factor);
   // regression test
-  EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
+  EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5));
 }

 /* ************************************************************************* */

@ -329,7 +330,7 @@ TEST(HybridBayesNet, Switching) {

   // Check the remaining factor for x1
   CHECK(factor_x1);
-  auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
+  auto phi_x1 = std::dynamic_pointer_cast<TableFactor>(factor_x1);
   CHECK(phi_x1);
   EXPECT_LONGS_EQUAL(1, phi_x1->keys().size());  // m0
   // We can't really check the error of the decision tree factor phi_x1, because

@ -650,7 +651,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
       mode, std::vector{conditional0, conditional1});

   // Add prior on mode.
-  expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "74/26");
+  expectedBayesNet.emplace_shared<TableDistribution>(mode, "74 26");

   // Test elimination
   const auto posterior = fg.eliminateSequential();

@ -700,11 +701,11 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
       m1, std::vector{conditional0, conditional1});

   // Add prior on m1.
-  expectedBayesNet.emplace_shared<DiscreteConditional>(m1, "1/1");
+  expectedBayesNet.emplace_shared<TableDistribution>(m1, "0.188638 0.811362");

   // Test elimination
   const auto posterior = fg.eliminateSequential();
-  // EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
+  EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));

   EXPECT(ratioTest(bn, measurements, *posterior));

@ -736,7 +737,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
       mode, std::vector{conditional0, conditional1});

   // Add prior on mode.
-  expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "23/77");
+  // Since this is the only discrete conditional, it is added as a
+  // TableDistribution.
+  expectedBayesNet.emplace_shared<TableDistribution>(mode, "23 77");

   // Test elimination
   const auto posterior = fg.eliminateSequential();

@ -19,6 +19,7 @@
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteDistribution.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/geometry/Pose2.h>
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridGaussianISAM.h>

@ -141,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) {
       expectedRemainingGraph->eliminateMultifrontal(discreteOrdering);

   // Test the probability values with regression tests.
-  auto discrete = isam[M(1)]->conditional()->asDiscrete();
+  auto discrete = isam[M(1)]->conditional()->asDiscrete<TableDistribution>();
   EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
   EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
   EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));

@ -221,16 +222,12 @@ TEST(HybridGaussianISAM, ApproxInference) {
   1 1 1 Leaf 0.5
   */

-  auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteConditional>(
+  auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
       incrementalHybrid[M(0)]->conditional()->inner());
   EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));

-  // Get the number of elements which are greater than 0.
-  auto count = [](const double &value, int count) {
-    return value > 0 ? count + 1 : count;
-  };
   // Check that the number of leaves after pruning is 5.
-  EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0));
+  EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues());

   // Check that the hybrid nodes of the bayes net match those of the pre-pruning
   // bayes net, at the same positions.

@ -477,7 +474,9 @@ TEST(HybridGaussianISAM, NonTrivial) {

   // Test if the optimal discrete mode assignment is (1, 1, 1).
   DiscreteFactorGraph discreteGraph;
-  discreteGraph.push_back(discreteTree);
+  // discreteTree is a TableDistribution, so we convert to
+  // DecisionTreeFactor for the DiscreteFactorGraph
+  discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
   DiscreteValues optimal_assignment = discreteGraph.optimize();

   DiscreteValues expected_assignment;

@ -22,6 +22,7 @@
 #include <gtsam/base/TestableAssertions.h>
 #include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/DiscreteValues.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>

@ -143,8 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
     // Since no measurement on x1, we hedge our bets
    // Importance sampling run with 100k samples gives 50.051/49.949
     // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
-    DiscreteConditional expected(m1, "50/50");
-    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete())));
+    TableDistribution expected(m1, "50 50");
+    EXPECT(
+        assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>())));
   }

   {

@ -160,8 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
     // Since we have a measurement on x1, we get a definite result
     // Values taken from an importance sampling run with 100k samples:
     // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
-    DiscreteConditional expected(m1, "44.3854/55.6146");
-    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
+    TableDistribution expected(m1, "44.3854 55.6146");
+    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
+                        0.02));
   }
 }

@ -248,8 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {

     // Values taken from an importance sampling run with 100k samples:
     // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
-    DiscreteConditional expected(m1, "48.3158/51.6842");
-    EXPECT(assert_equal(expected, *(eliminated->at(2)->asDiscrete()), 0.002));
+    TableDistribution expected(m1, "48.3158 51.6842");
+    EXPECT(assert_equal(
+        expected, *(eliminated->at(2)->asDiscrete<TableDistribution>()), 0.02));
   }

   {

@ -263,8 +267,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {

     // Values taken from an importance sampling run with 100k samples:
     // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
-    DiscreteConditional expected(m1, "55.396/44.604");
-    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
+    TableDistribution expected(m1, "55.396 44.604");
+    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
+                        0.02));
   }
 }

@ -340,8 +345,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {

     // Values taken from an importance sampling run with 100k samples:
     // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
-    DiscreteConditional expected(m1, "51.7762/48.2238");
-    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
+    TableDistribution expected(m1, "51.7762 48.2238");
+    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
+                        0.02));
   }

   {

@ -355,8 +361,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {

     // Values taken from an importance sampling run with 100k samples:
     // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
-    DiscreteConditional expected(m1, "49.0762/50.9238");
-    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.005));
+    TableDistribution expected(m1, "49.0762 50.9238");
+    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
+                        0.05));
   }
 }

@ -381,8 +388,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) {

     // Values taken from an importance sampling run with 100k samples:
     // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
-    DiscreteConditional expected(m1, "8.91527/91.0847");
-    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
+    TableDistribution expected(m1, "8.91527 91.0847");
+    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
+                        0.01));
   }

 /* ************************************************************************* */

@ -537,8 +545,8 @@ TEST(HybridGaussianFactorGraph, DifferentCovariances) {
   DiscreteValues dv0{{M(1), 0}};
   DiscreteValues dv1{{M(1), 1}};

-  DiscreteConditional expected_m1(m1, "0.5/0.5");
-  DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
+  TableDistribution expected_m1(m1, "0.5 0.5");
+  TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete<TableDistribution>());

   EXPECT(assert_equal(expected_m1, actual_m1));
 }
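Most of the churn in these tests is a single accessor pattern: conditionals that used to be fetched with asDiscrete() are now fetched with the templated asDiscrete<TableDistribution>(), which returns a null pointer when the stored conditional is of a different type. A defensive sketch of that pattern, mirroring the fallback used in HybridNonlinearISAM::reorderRelinearize earlier in this diff; the helper and variable names are illustrative:

    #include <gtsam/discrete/TableDistribution.h>
    #include <gtsam/hybrid/HybridConditional.h>

    using namespace gtsam;

    // Fetch the discrete part of a hybrid conditional, preferring the
    // TableDistribution view produced by the new elimination path and
    // falling back to the generic DiscreteConditional view otherwise.
    DiscreteConditional::shared_ptr discretePart(
        const HybridConditional::shared_ptr& conditional) {
      if (auto table = conditional->asDiscrete<TableDistribution>()) {
        return table;  // stored as a TableDistribution
      }
      return conditional->asDiscrete();  // any other DiscreteConditional
    }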
@ -20,6 +20,7 @@
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteDistribution.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/geometry/Pose2.h>
 #include <gtsam/hybrid/HybridEliminationTree.h>
 #include <gtsam/hybrid/HybridFactor.h>

@ -368,10 +369,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
   EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());

   // This is now a discreteFactor
-  auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
+  auto discreteFactor = dynamic_pointer_cast<TableFactor>(factorOnModes);
   CHECK(discreteFactor);
   EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
-  EXPECT(discreteFactor->root_->isLeaf() == false);
 }

 /****************************************************************************

@ -513,9 +513,10 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
   // P(m1)
   EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)});
   EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents());
-  EXPECT(
-      dynamic_pointer_cast<DiscreteConditional>(hybridBayesNet->at(4)->inner())
-          ->equals(*discreteBayesNet.at(1)));
+  TableDistribution dtc =
+      *hybridBayesNet->at(4)->asDiscrete<TableDistribution>();
+  EXPECT(DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor())
+             .equals(*discreteBayesNet.at(1)));
 }

 /****************************************************************************
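Where a regression test still compares against a conditional produced by plain discrete elimination (as in Full_Elimination above), the TableDistribution is converted back through its decision-tree form. A one-function sketch of that conversion, assuming only the accessors used in the test (nrFrontals, toDecisionTreeFactor); purely illustrative:

    #include <gtsam/discrete/DiscreteConditional.h>
    #include <gtsam/discrete/TableDistribution.h>

    using namespace gtsam;

    // Rebuild an equivalent DiscreteConditional from a TableDistribution so
    // it can be compared (equals / assert_equal) with conditionals coming
    // from a DiscreteBayesNet.
    DiscreteConditional toDiscreteConditional(const TableDistribution& dtc) {
      return DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor());
    }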
@ -1062,8 +1063,8 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) {
   DiscreteValues dv0{{M(1), 0}};
   DiscreteValues dv1{{M(1), 1}};

-  DiscreteConditional expected_m1(m1, "0.5/0.5");
-  DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
+  TableDistribution expected_m1(m1, "0.5 0.5");
+  TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete<TableDistribution>());

   EXPECT(assert_equal(expected_m1, actual_m1));
 }

@ -19,6 +19,7 @@
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteDistribution.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/geometry/Pose2.h>
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridNonlinearISAM.h>

@ -265,16 +266,12 @@ TEST(HybridNonlinearISAM, ApproxInference) {
   1 1 1 Leaf 0.5
   */

-  auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteConditional>(
+  auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
       bayesTree[M(0)]->conditional()->inner());
   EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));

-  // Get the number of elements which are greater than 0.
-  auto count = [](const double &value, int count) {
-    return value > 0 ? count + 1 : count;
-  };
   // Check that the number of leaves after pruning is 5.
-  EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0));
+  EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues());

   // Check that the hybrid nodes of the bayes net match those of the pre-pruning
   // bayes net, at the same positions.

@ -520,12 +517,13 @@ TEST(HybridNonlinearISAM, NonTrivial) {

   // The final discrete graph should not be empty since we have eliminated
   // all continuous variables.
-  auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete();
+  auto discreteTree =
+      bayesTree[M(3)]->conditional()->asDiscrete<TableDistribution>();
   EXPECT_LONGS_EQUAL(3, discreteTree->size());

   // Test if the optimal discrete mode assignment is (1, 1, 1).
   DiscreteFactorGraph discreteGraph;
-  discreteGraph.push_back(discreteTree);
+  discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
   DiscreteValues optimal_assignment = discreteGraph.optimize();

   DiscreteValues expected_assignment;

@ -18,6 +18,7 @@

 #include <gtsam/base/serializationTestHelpers.h>
 #include <gtsam/discrete/DiscreteConditional.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridBayesTree.h>
 #include <gtsam/hybrid/HybridConditional.h>

@ -44,6 +45,7 @@ BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
 BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
 BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
 BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
+BOOST_CLASS_EXPORT_GUID(TableDistribution, "gtsam_TableDistribution");

 BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
 using ADT = AlgebraicDecisionTree<Key>;

@ -13,14 +13,14 @@ Author: Fan Jiang, Varun Agrawal, Frank Dellaert
 import unittest

 import numpy as np
-from gtsam.symbol_shorthand import C, M, X, Z
-from gtsam.utils.test_case import GtsamTestCase

 import gtsam
-from gtsam import (DiscreteConditional, GaussianConditional,
-                   HybridBayesNet, HybridGaussianConditional,
-                   HybridGaussianFactor, HybridGaussianFactorGraph,
-                   HybridValues, JacobianFactor, noiseModel)
+from gtsam import (DiscreteConditional, GaussianConditional, HybridBayesNet,
+                   HybridGaussianConditional, HybridGaussianFactor,
+                   HybridGaussianFactorGraph, HybridValues, JacobianFactor,
+                   TableDistribution, noiseModel)
+from gtsam.symbol_shorthand import C, M, X, Z
+from gtsam.utils.test_case import GtsamTestCase

 DEBUG_MARGINALS = False

@ -51,7 +51,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
         self.assertEqual(len(hybridCond.keys()), 2)

         discrete_conditional = hbn.at(hbn.size() - 1).inner()
-        self.assertIsInstance(discrete_conditional, DiscreteConditional)
+        self.assertIsInstance(discrete_conditional, TableDistribution)

     def test_optimize(self):
         """Test construction of hybrid factor graph."""