Merge pull request #1948 from borglab/hybrid-timing
commit
169523ecc6
|
@ -168,6 +168,12 @@ foreach(build_type "common" ${GTSAM_CMAKE_CONFIGURATION_TYPES})
|
|||
append_config_if_not_empty(GTSAM_COMPILE_DEFINITIONS_PUBLIC ${build_type})
|
||||
endforeach()
|
||||
|
||||
# Check if timing is enabled and add appropriate definition flag
|
||||
if(GTSAM_ENABLE_TIMING AND(NOT ${CMAKE_BUILD_TYPE} EQUAL "Timing"))
|
||||
message(STATUS "Enabling timing for non-timing build")
|
||||
list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE "ENABLE_TIMING")
|
||||
endif()
|
||||
|
||||
# Linker flags:
|
||||
set(GTSAM_CMAKE_SHARED_LINKER_FLAGS_TIMING "${CMAKE_SHARED_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.")
|
||||
set(GTSAM_CMAKE_MODULE_LINKER_FLAGS_TIMING "${CMAKE_MODULE_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.")
|
||||
|
|
|
@ -33,6 +33,8 @@ option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Qu
|
|||
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
|
||||
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
|
||||
option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON)
|
||||
option(GTSAM_ENABLE_TIMING "Enable the timing tools (gttic/gttoc)" OFF)
|
||||
option(GTSAM_HYBRID_TIMING "Enable the timing of hybrid factor graph machinery" OFF)
|
||||
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
|
||||
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
|
||||
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)
|
||||
|
|
|
@ -91,6 +91,7 @@ print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory San
|
|||
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
|
||||
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
|
||||
print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree")
|
||||
print_enabled_config(${GTSAM_ENABLE_TIMING} "Enable timing machinery")
|
||||
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
|
||||
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
|
||||
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
|
||||
|
|
|
@ -31,7 +31,9 @@
|
|||
|
||||
namespace gtsam {
|
||||
namespace internal {
|
||||
|
||||
|
||||
using ChildOrder = FastMap<size_t, std::shared_ptr<TimingOutline>>;
|
||||
|
||||
// a static shared_ptr to TimingOutline with nullptr as the pointer
|
||||
const static std::shared_ptr<TimingOutline> nullTimingOutline;
|
||||
|
||||
|
@ -91,7 +93,6 @@ void TimingOutline::print(const std::string& outline) const {
|
|||
<< n_ << " times, " << wall() << " wall, " << secs() << " children, min: "
|
||||
<< min() << " max: " << max() << ")\n";
|
||||
// Order children
|
||||
typedef FastMap<size_t, std::shared_ptr<TimingOutline> > ChildOrder;
|
||||
ChildOrder childOrder;
|
||||
for(const ChildMap::value_type& child: children_) {
|
||||
childOrder[child.second->myOrder_] = child.second;
|
||||
|
@ -106,6 +107,54 @@ void TimingOutline::print(const std::string& outline) const {
|
|||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void TimingOutline::printCsvHeader(bool addLineBreak) const {
|
||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
||||
// Order is (CPU time, number of times, wall time, time + children in seconds,
|
||||
// min time, max time)
|
||||
std::cout << label_ + " cpu time (s)" << "," << label_ + " #calls" << ","
|
||||
<< label_ + " wall time(s)" << "," << label_ + " subtree time (s)"
|
||||
<< "," << label_ + " min time (s)" << "," << label_ + "max time(s)"
|
||||
<< ",";
|
||||
// Order children
|
||||
ChildOrder childOrder;
|
||||
for (const ChildMap::value_type& child : children_) {
|
||||
childOrder[child.second->myOrder_] = child.second;
|
||||
}
|
||||
// Print children
|
||||
for (const ChildOrder::value_type& order_child : childOrder) {
|
||||
order_child.second->printCsvHeader();
|
||||
}
|
||||
if (addLineBreak) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout.flush();
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void TimingOutline::printCsv(bool addLineBreak) const {
|
||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
||||
// Order is (CPU time, number of times, wall time, time + children in seconds,
|
||||
// min time, max time)
|
||||
std::cout << self() << "," << n_ << "," << wall() << "," << secs() << ","
|
||||
<< min() << "," << max() << ",";
|
||||
// Order children
|
||||
ChildOrder childOrder;
|
||||
for (const ChildMap::value_type& child : children_) {
|
||||
childOrder[child.second->myOrder_] = child.second;
|
||||
}
|
||||
// Print children
|
||||
for (const ChildOrder::value_type& order_child : childOrder) {
|
||||
order_child.second->printCsv(false);
|
||||
}
|
||||
if (addLineBreak) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout.flush();
|
||||
#endif
|
||||
}
|
||||
|
||||
void TimingOutline::print2(const std::string& outline,
|
||||
const double parentTotal) const {
|
||||
#if GTSAM_USE_BOOST_FEATURES
|
||||
|
|
|
@ -199,6 +199,29 @@ namespace gtsam {
|
|||
#endif
|
||||
GTSAM_EXPORT void print(const std::string& outline = "") const;
|
||||
GTSAM_EXPORT void print2(const std::string& outline = "", const double parentTotal = -1.0) const;
|
||||
|
||||
/**
|
||||
* @brief Print the CSV header.
|
||||
* Order is
|
||||
* (CPU time, number of times, wall time, time + children in seconds, min
|
||||
* time, max time)
|
||||
*
|
||||
* @param addLineBreak Flag indicating if a line break should be added at
|
||||
* the end. Only used at the top-leve.
|
||||
*/
|
||||
GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const;
|
||||
|
||||
/**
|
||||
* @brief Print the times recursively from parent to child in CSV format.
|
||||
* For each timing node, the output is
|
||||
* (CPU time, number of times, wall time, time + children in seconds, min
|
||||
* time, max time)
|
||||
*
|
||||
* @param addLineBreak Flag indicating if a line break should be added at
|
||||
* the end. Only used at the top-leve.
|
||||
*/
|
||||
GTSAM_EXPORT void printCsv(bool addLineBreak = false) const;
|
||||
|
||||
GTSAM_EXPORT const std::shared_ptr<TimingOutline>&
|
||||
child(size_t child, const std::string& label, const std::weak_ptr<TimingOutline>& thisPtr);
|
||||
GTSAM_EXPORT void tic();
|
||||
|
@ -268,6 +291,14 @@ inline void tictoc_finishedIteration_() {
|
|||
inline void tictoc_print_() {
|
||||
::gtsam::internal::gTimingRoot->print(); }
|
||||
|
||||
// print timing in CSV format
|
||||
inline void tictoc_printCsv_(bool displayHeader = false) {
|
||||
if (displayHeader) {
|
||||
::gtsam::internal::gTimingRoot->printCsvHeader(true);
|
||||
}
|
||||
::gtsam::internal::gTimingRoot->printCsv(true);
|
||||
}
|
||||
|
||||
// print mean and standard deviation
|
||||
inline void tictoc_print2_() {
|
||||
::gtsam::internal::gTimingRoot->print2(); }
|
||||
|
|
|
@ -42,6 +42,9 @@
|
|||
// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
|
||||
#cmakedefine GTSAM_DT_MERGING
|
||||
|
||||
// Whether to enable timing in hybrid factor graph machinery
|
||||
#cmakedefine01 GTSAM_HYBRID_TIMING
|
||||
|
||||
// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
|
||||
#cmakedefine GTSAM_USE_TBB
|
||||
|
||||
|
|
|
@ -57,6 +57,9 @@ namespace gtsam {
|
|||
|
||||
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
|
||||
|
||||
/// Constructor which accepts root pointer
|
||||
AlgebraicDecisionTree(const typename Base::NodePtr root) : Base(root) {}
|
||||
|
||||
// Explicitly non-explicit constructor
|
||||
AlgebraicDecisionTree(const Base& add) : Base(add) {}
|
||||
|
||||
|
|
|
@ -77,6 +77,13 @@ DiscreteConditional::DiscreteConditional(const Signature& signature)
|
|||
/* ************************************************************************** */
|
||||
DiscreteConditional DiscreteConditional::operator*(
|
||||
const DiscreteConditional& other) const {
|
||||
// If the root is a nullptr, we have a TableDistribution
|
||||
// TODO(Varun) Revisit this hack after RSS2025 submission
|
||||
if (!other.root_) {
|
||||
DiscreteConditional dc(other.nrFrontals(), other.toDecisionTreeFactor());
|
||||
return dc * (*this);
|
||||
}
|
||||
|
||||
// Take union of frontal keys
|
||||
std::set<Key> newFrontals;
|
||||
for (auto&& key : this->frontals()) newFrontals.insert(key);
|
||||
|
@ -479,6 +486,19 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
|
|||
return this->operator()(x.discrete());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::shared_ptr DiscreteConditional::max(
|
||||
const Ordering& keys) const {
|
||||
return BaseFactor::max(keys);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DiscreteConditional::prune(size_t maxNrAssignments) {
|
||||
// Get as DiscreteConditional so the probabilities are normalized
|
||||
DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments));
|
||||
this->root_ = pruned.root_;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteConditional::negLogConstant() const { return 0.0; }
|
||||
|
||||
|
|
|
@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
* @param parentsValues Known values of the parents
|
||||
* @return sample from conditional
|
||||
*/
|
||||
size_t sample(const DiscreteValues& parentsValues) const;
|
||||
virtual size_t sample(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/// Single parent version.
|
||||
size_t sample(size_t parent_value) const;
|
||||
|
@ -214,6 +214,15 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
*/
|
||||
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
|
||||
|
||||
/**
|
||||
* @brief Create new factor by maximizing over all
|
||||
* values with the same separator.
|
||||
*
|
||||
* @param keys The keys to sum over.
|
||||
* @return DiscreteFactor::shared_ptr
|
||||
*/
|
||||
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
@ -267,6 +276,9 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
*/
|
||||
double negLogConstant() const override;
|
||||
|
||||
/// Prune the conditional
|
||||
virtual void prune(size_t maxNrAssignments);
|
||||
|
||||
/// @}
|
||||
|
||||
protected:
|
||||
|
|
|
@ -118,17 +118,11 @@ namespace gtsam {
|
|||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* @brief Multiply all the `factors`.
|
||||
*
|
||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||
* @return DiscreteFactor::shared_ptr
|
||||
*/
|
||||
static DiscreteFactor::shared_ptr DiscreteProduct(
|
||||
const DiscreteFactorGraph& factors) {
|
||||
/* ************************************************************************ */
|
||||
DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DiscreteFactor::shared_ptr product = factors.product();
|
||||
DiscreteFactor::shared_ptr product = this->product();
|
||||
gttoc(product);
|
||||
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
|
@ -145,7 +139,7 @@ namespace gtsam {
|
|||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||
DiscreteFactor::shared_ptr product = factors.scaledProduct();
|
||||
|
||||
// max out frontals, this is the factor on the separator
|
||||
gttic(max);
|
||||
|
@ -223,7 +217,7 @@ namespace gtsam {
|
|||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||
DiscreteFactor::shared_ptr product = factors.scaledProduct();
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
gttic(sum);
|
||||
|
|
|
@ -150,6 +150,15 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
/** return product of all factors as a single factor */
|
||||
DiscreteFactor::shared_ptr product() const;
|
||||
|
||||
/**
|
||||
* @brief Return product of all `factors` as a single factor,
|
||||
* which is scaled by the max value to prevent underflow
|
||||
*
|
||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||
* @return DiscreteFactor::shared_ptr
|
||||
*/
|
||||
DiscreteFactor::shared_ptr scaledProduct() const;
|
||||
|
||||
/**
|
||||
* Evaluates the factor graph given values, returns the joint probability of
|
||||
* the factor graph given specific instantiation of values
|
||||
|
|
|
@ -0,0 +1,174 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file TableDistribution.cpp
|
||||
* @date Dec 22, 2024
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
#include <gtsam/discrete/Ring.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using std::pair;
|
||||
using std::stringstream;
|
||||
using std::vector;
|
||||
namespace gtsam {
|
||||
|
||||
/// Normalize sparse_table
|
||||
static Eigen::SparseVector<double> normalizeSparseTable(
|
||||
const Eigen::SparseVector<double>& sparse_table) {
|
||||
return sparse_table / sparse_table.sum();
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
TableDistribution::TableDistribution(const TableFactor& f)
|
||||
: BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)),
|
||||
table_(f / (*std::dynamic_pointer_cast<TableFactor>(
|
||||
f.sum(f.keys().size())))) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
TableDistribution::TableDistribution(const DiscreteKeys& keys,
|
||||
const std::vector<double>& potentials)
|
||||
: BaseConditional(keys.size(), keys, ADT(nullptr)),
|
||||
table_(TableFactor(
|
||||
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
TableDistribution::TableDistribution(const DiscreteKeys& keys,
|
||||
const std::string& potentials)
|
||||
: BaseConditional(keys.size(), keys, ADT(nullptr)),
|
||||
table_(TableFactor(
|
||||
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void TableDistribution::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << s << " P( ";
|
||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||
cout << formatter(*it) << " ";
|
||||
}
|
||||
cout << "):\n";
|
||||
table_.print("", formatter);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
bool TableDistribution::equals(const DiscreteFactor& other, double tol) const {
|
||||
auto dtc = dynamic_cast<const TableDistribution*>(&other);
|
||||
if (!dtc) {
|
||||
return false;
|
||||
} else {
|
||||
const DiscreteConditional& f(
|
||||
static_cast<const DiscreteConditional&>(other));
|
||||
return table_.equals(dtc->table_, tol) &&
|
||||
DiscreteConditional::BaseConditional::equals(f, tol);
|
||||
}
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const {
|
||||
return table_.sum(nrFrontals);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const {
|
||||
return table_.sum(keys);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const {
|
||||
return table_.max(nrFrontals);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const {
|
||||
return table_.max(keys);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
DiscreteFactor::shared_ptr TableDistribution::operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const {
|
||||
return table_ / f;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteValues TableDistribution::argmax() const {
|
||||
uint64_t maxIdx = 0;
|
||||
double maxValue = 0.0;
|
||||
|
||||
Eigen::SparseVector<double> sparseTable = table_.sparseTable();
|
||||
|
||||
for (SparseIt it(sparseTable); it; ++it) {
|
||||
if (it.value() > maxValue) {
|
||||
maxIdx = it.index();
|
||||
maxValue = it.value();
|
||||
}
|
||||
}
|
||||
|
||||
return table_.findAssignments(maxIdx);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
void TableDistribution::prune(size_t maxNrAssignments) {
|
||||
table_ = table_.prune(maxNrAssignments);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
|
||||
static mt19937 rng(2); // random number generator
|
||||
|
||||
DiscreteKeys parentsKeys;
|
||||
for (auto&& [key, _] : parentsValues) {
|
||||
parentsKeys.push_back({key, table_.cardinality(key)});
|
||||
}
|
||||
|
||||
// Get the correct conditional distribution: P(F|S=parentsValues)
|
||||
TableFactor pFS = table_.choose(parentsValues, parentsKeys);
|
||||
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::invalid_argument(
|
||||
"TableDistribution::sample can only be called on single variable "
|
||||
"conditionals");
|
||||
}
|
||||
Key key = firstFrontalKey();
|
||||
size_t nj = cardinality(key);
|
||||
vector<double> p(nj);
|
||||
DiscreteValues frontals;
|
||||
for (size_t value = 0; value < nj; value++) {
|
||||
frontals[key] = value;
|
||||
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
if (p[value] == 1.0) {
|
||||
return value; // shortcut exit
|
||||
}
|
||||
}
|
||||
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
||||
return distribution(rng);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,177 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file TableDistribution.h
|
||||
* @date Dec 22, 2024
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* Distribution which uses a SparseVector as the internal
|
||||
* representation, similar to the TableFactor.
|
||||
*
|
||||
* This is primarily used in the case when we have a clique in the BayesTree
|
||||
* which consists of all the discrete variables, e.g. in hybrid elimination.
|
||||
*
|
||||
* @ingroup discrete
|
||||
*/
|
||||
class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
||||
private:
|
||||
TableFactor table_;
|
||||
|
||||
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
||||
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef TableDistribution This; ///< Typedef to this class
|
||||
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
typedef DiscreteConditional
|
||||
BaseConditional; ///< Typedef to our conditional base class
|
||||
|
||||
using Values = DiscreteValues; ///< backwards compatibility
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor needed for serialization.
|
||||
TableDistribution() {}
|
||||
|
||||
/// Construct from TableFactor.
|
||||
TableDistribution(const TableFactor& f);
|
||||
|
||||
/**
|
||||
* Construct from DiscreteKeys and std::vector.
|
||||
*/
|
||||
TableDistribution(const DiscreteKeys& keys,
|
||||
const std::vector<double>& potentials);
|
||||
|
||||
/**
|
||||
* Construct from single DiscreteKey and std::vector.
|
||||
*/
|
||||
TableDistribution(const DiscreteKey& key,
|
||||
const std::vector<double>& potentials)
|
||||
: TableDistribution(DiscreteKeys(key), potentials) {}
|
||||
|
||||
/**
|
||||
* Construct from DiscreteKey and std::string.
|
||||
*/
|
||||
TableDistribution(const DiscreteKeys& keys, const std::string& potentials);
|
||||
|
||||
/**
|
||||
* Construct from single DiscreteKey and std::string.
|
||||
*/
|
||||
TableDistribution(const DiscreteKey& key, const std::string& potentials)
|
||||
: TableDistribution(DiscreteKeys(key), potentials) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Table Distribution: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// GTSAM-style equals
|
||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Return the underlying TableFactor
|
||||
TableFactor table() const { return table_; }
|
||||
|
||||
using BaseConditional::evaluate; // HybridValues version
|
||||
|
||||
/// Evaluate the conditional given the values.
|
||||
virtual double evaluate(const Assignment<Key>& values) const override {
|
||||
return table_.evaluate(values);
|
||||
}
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override;
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const override;
|
||||
|
||||
/**
|
||||
* @brief Return assignment that maximizes value.
|
||||
*
|
||||
* @return maximizing assignment for the variables.
|
||||
*/
|
||||
DiscreteValues argmax() const;
|
||||
|
||||
/**
|
||||
* sample
|
||||
* @param parentsValues Known values of the parents
|
||||
* @return sample from conditional
|
||||
*/
|
||||
virtual size_t sample(const DiscreteValues& parentsValues) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/// Prune the conditional
|
||||
virtual void prune(size_t maxNrAssignments) override;
|
||||
|
||||
/// Get a DecisionTreeFactor representation.
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override {
|
||||
return table_.toDecisionTreeFactor();
|
||||
}
|
||||
|
||||
/// Get the number of non-zero values.
|
||||
uint64_t nrValues() const override { return table_.sparseTable().nonZeros(); }
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
ar& BOOST_SERIALIZATION_NVP(table_);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
// TableDistribution
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<TableDistribution> : public Testable<TableDistribution> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -87,6 +87,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* Convert probability table given as doubles to SparseVector.
|
||||
* Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
|
||||
|
@ -98,7 +99,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
|
||||
const std::string& table);
|
||||
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef TableFactor This;
|
||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||
|
@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||
|
||||
/// Create a TableFactor that is a subset of this TableFactor
|
||||
TableFactor choose(const DiscreteValues assignments,
|
||||
TableFactor choose(const DiscreteValues parentAssignments,
|
||||
DiscreteKeys parent_keys) const;
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
|
|
|
@ -168,6 +168,43 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
|||
std::vector<double> pmf() const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
virtual class TableFactor : gtsam::DiscreteFactor {
|
||||
TableFactor();
|
||||
TableFactor(const gtsam::DiscreteKeys& keys,
|
||||
const gtsam::TableFactor& potentials);
|
||||
TableFactor(const gtsam::DiscreteKeys& keys, std::vector<double>& table);
|
||||
TableFactor(const gtsam::DiscreteKeys& keys, string spec);
|
||||
TableFactor(const gtsam::DiscreteKeys& keys,
|
||||
const gtsam::DecisionTreeFactor& dtf);
|
||||
TableFactor(const gtsam::DecisionTreeFactor& dtf);
|
||||
|
||||
void print(string s = "TableFactor\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double error(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
virtual class TableDistribution : gtsam::DiscreteConditional {
|
||||
TableDistribution();
|
||||
TableDistribution(const gtsam::TableFactor& f);
|
||||
TableDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
|
||||
TableDistribution(const gtsam::DiscreteKeys& keys, std::vector<double> spec);
|
||||
TableDistribution(const gtsam::DiscreteKeys& keys, string spec);
|
||||
TableDistribution(const gtsam::DiscreteKey& key, string spec);
|
||||
|
||||
void print(string s = "Table Distribution\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
||||
gtsam::TableFactor table() const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
size_t nrValues() const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
class DiscreteBayesNet {
|
||||
DiscreteBayesNet();
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
|
@ -55,12 +56,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
|||
joint = joint * (*conditional);
|
||||
}
|
||||
|
||||
// Prune the joint. NOTE: again, possibly quite expensive.
|
||||
const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);
|
||||
|
||||
// Create a the result starting with the pruned joint.
|
||||
// Create the result starting with the pruned joint.
|
||||
HybridBayesNet result;
|
||||
result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);
|
||||
result.emplace_shared<DiscreteConditional>(joint);
|
||||
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
||||
result.back()->asDiscrete()->prune(maxNrLeaves);
|
||||
|
||||
// Get pruned discrete probabilities so
|
||||
// we can prune HybridGaussianConditionals.
|
||||
DiscreteConditional pruned = *result.back()->asDiscrete();
|
||||
|
||||
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
|
@ -126,7 +130,14 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_fg.push_back(conditional->asDiscrete());
|
||||
if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
|
||||
// The number of keys should be small so should not
|
||||
// be expensive to convert to DiscreteConditional.
|
||||
discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
|
||||
dtc->toDecisionTreeFactor()));
|
||||
} else {
|
||||
discrete_fg.push_back(conditional->asDiscrete());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <gtsam/base/treeTraversal-inst.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
|
@ -41,6 +42,22 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
|
|||
return Base::equals(other, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues HybridBayesTree::discreteMaxProduct(
|
||||
const DiscreteFactorGraph& dfg) const {
|
||||
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
|
||||
|
||||
// Check type of product, and get as TableFactor for efficiency.
|
||||
TableFactor p;
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
|
||||
p = *tf;
|
||||
} else {
|
||||
p = TableFactor(product->toDecisionTreeFactor());
|
||||
}
|
||||
DiscreteValues assignment = TableDistribution(p).argmax();
|
||||
return assignment;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesTree::optimize() const {
|
||||
DiscreteFactorGraph discrete_fg;
|
||||
|
@ -52,8 +69,9 @@ HybridValues HybridBayesTree::optimize() const {
|
|||
|
||||
// The root should be discrete only, we compute the MPE
|
||||
if (root_conditional->isDiscrete()) {
|
||||
discrete_fg.push_back(root_conditional->asDiscrete());
|
||||
mpe = discrete_fg.optimize();
|
||||
auto discrete = root_conditional->asDiscrete<TableDistribution>();
|
||||
discrete_fg.push_back(discrete);
|
||||
mpe = discreteMaxProduct(discrete_fg);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"HybridBayesTree root is not discrete-only. Please check elimination "
|
||||
|
@ -179,16 +197,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
|
||||
auto prunedDiscreteProbs =
|
||||
this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
|
||||
|
||||
DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
|
||||
discreteProbs->root_ = prunedDiscreteProbs.root_;
|
||||
// Imperative pruning
|
||||
prunedDiscreteProbs->prune(maxNrLeaves);
|
||||
|
||||
/// Helper struct for pruning the hybrid bayes tree.
|
||||
struct HybridPrunerData {
|
||||
/// The discrete decision tree after pruning.
|
||||
DecisionTreeFactor prunedDiscreteProbs;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
|
||||
DiscreteConditional::shared_ptr prunedDiscreteProbs;
|
||||
HybridPrunerData(const DiscreteConditional::shared_ptr& prunedDiscreteProbs,
|
||||
const HybridBayesTree::sharedNode& parentClique)
|
||||
: prunedDiscreteProbs(prunedDiscreteProbs) {}
|
||||
|
||||
|
@ -213,7 +232,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
|||
if (!hybridGaussianCond->pruned()) {
|
||||
// Imperative
|
||||
clique->conditional() = std::make_shared<HybridConditional>(
|
||||
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
|
||||
hybridGaussianCond->prune(*parentData.prunedDiscreteProbs));
|
||||
}
|
||||
}
|
||||
return parentData;
|
||||
|
|
|
@ -115,6 +115,10 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
|||
/// @}
|
||||
|
||||
private:
|
||||
/// Helper method to compute the max product assignment
|
||||
/// given a DiscreteFactorGraph
|
||||
DiscreteValues discreteMaxProduct(const DiscreteFactorGraph& dfg) const;
|
||||
|
||||
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
@ -166,12 +166,13 @@ class GTSAM_EXPORT HybridConditional
|
|||
}
|
||||
|
||||
/**
|
||||
* @brief Return conditional as a DiscreteConditional
|
||||
* @brief Return conditional as a DiscreteConditional or specified type T.
|
||||
* @return nullptr if not a DiscreteConditional
|
||||
* @return DiscreteConditional::shared_ptr
|
||||
*/
|
||||
DiscreteConditional::shared_ptr asDiscrete() const {
|
||||
return std::dynamic_pointer_cast<DiscreteConditional>(inner_);
|
||||
template <typename T = DiscreteConditional>
|
||||
typename T::shared_ptr asDiscrete() const {
|
||||
return std::dynamic_pointer_cast<T>(inner_);
|
||||
}
|
||||
|
||||
/// Get the type-erased pointer to the inner type
|
||||
|
|
|
@ -304,7 +304,7 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
|||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||
const DecisionTreeFactor &discreteProbs) const {
|
||||
const DiscreteConditional &discreteProbs) const {
|
||||
// Find keys in discreteProbs.keys() but not in this->keys():
|
||||
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
||||
std::set<Key> theirs(discreteProbs.keys().begin(),
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
|
@ -235,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
* @return Shared pointer to possibly a pruned HybridGaussianConditional
|
||||
*/
|
||||
HybridGaussianConditional::shared_ptr prune(
|
||||
const DecisionTreeFactor &discreteProbs) const;
|
||||
const DiscreteConditional &discreteProbs) const;
|
||||
|
||||
/// Return true if the conditional has already been pruned.
|
||||
bool pruned() const { return pruned_; }
|
||||
|
|
|
@ -20,12 +20,13 @@
|
|||
|
||||
#include <gtsam/base/utilities.h>
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
@ -241,29 +242,29 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
|||
/* ************************************************************************ */
|
||||
/**
|
||||
* @brief Take negative log-values, shift them so that the minimum value is 0,
|
||||
* and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
|
||||
* and then exponentiate to create a TableFactor (not normalized yet!).
|
||||
*
|
||||
* @param errors DecisionTree of (unnormalized) errors.
|
||||
* @return DecisionTreeFactor::shared_ptr
|
||||
* @return TableFactor::shared_ptr
|
||||
*/
|
||||
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
|
||||
static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const AlgebraicDecisionTree<Key> &errors) {
|
||||
double min_log = errors.min();
|
||||
AlgebraicDecisionTree<Key> potentials(
|
||||
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
|
||||
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
|
||||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
static DiscreteFactorGraph CollectDiscreteFactors(
|
||||
const HybridGaussianFactorGraph &factors) {
|
||||
DiscreteFactorGraph dfg;
|
||||
|
||||
for (auto &f : factors) {
|
||||
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
dfg.push_back(df);
|
||||
|
||||
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
// Case where we have a HybridGaussianFactor with no continuous keys.
|
||||
// In this case, compute a discrete factor from the remaining error.
|
||||
|
@ -282,16 +283,73 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
auto dc = hc->asDiscrete();
|
||||
if (!dc) throwRuntimeError("discreteElimination", dc);
|
||||
dfg.push_back(dc);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(ConvertConditionalToTableFactor);
|
||||
#endif
|
||||
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
|
||||
/// Get the underlying TableFactor
|
||||
dfg.push_back(dtc->table());
|
||||
} else {
|
||||
// Convert DiscreteConditional to TableFactor
|
||||
auto tdc = std::make_shared<TableFactor>(*dc);
|
||||
dfg.push_back(tdc);
|
||||
}
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(ConvertConditionalToTableFactor);
|
||||
#endif
|
||||
} else {
|
||||
throwRuntimeError("discreteElimination", f);
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
return dfg;
|
||||
}
|
||||
|
||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscrete);
|
||||
#endif
|
||||
// Check if separator is empty.
|
||||
// This is the same as checking if the number of frontal variables
|
||||
// is the same as the number of variables in the DiscreteFactorGraph.
|
||||
// If the separator is empty, we have a clique of all the discrete variables
|
||||
// so we can use the TableFactor for efficiency.
|
||||
if (frontalKeys.size() == dfg.keys().size()) {
|
||||
// Get product factor
|
||||
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
// Check type of product, and get as TableFactor for efficiency.
|
||||
TableFactor p;
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
|
||||
p = *tf;
|
||||
} else {
|
||||
p = TableFactor(product->toDecisionTreeFactor());
|
||||
}
|
||||
auto conditional = std::make_shared<TableDistribution>(p);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
|
||||
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
|
||||
|
||||
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||
|
||||
} else {
|
||||
// Perform sum-product.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||
}
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscrete);
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
@ -319,8 +377,19 @@ static std::shared_ptr<Factor> createDiscreteFactor(
|
|||
}
|
||||
};
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteBoundaryErrors);
|
||||
#endif
|
||||
AlgebraicDecisionTree<Key> errors(eliminationResults, calculateError);
|
||||
return DiscreteFactorFromErrors(discreteSeparator, errors);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteBoundaryErrors);
|
||||
gttic_(DiscreteBoundaryResult);
|
||||
#endif
|
||||
auto result = DiscreteFactorFromErrors(discreteSeparator, errors);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteBoundaryResult);
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
@ -360,12 +429,18 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
|||
// the discrete separator will be *all* the discrete keys.
|
||||
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(HybridCollectProductFactor);
|
||||
#endif
|
||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||
// decision tree indexed by all discrete keys involved. Just like any hybrid
|
||||
// factor, every assignment also has a scalar error, in this case the sum of
|
||||
// all errors in the graph. This error is assignment-specific and accounts for
|
||||
// any difference in noise models used.
|
||||
HybridGaussianProductFactor productFactor = collectProductFactor();
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(HybridCollectProductFactor);
|
||||
#endif
|
||||
|
||||
// Check if a factor is null
|
||||
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
|
||||
|
@ -393,8 +468,14 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
|||
return {conditional, conditional->negLogConstant(), factor, scalar};
|
||||
};
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(HybridEliminate);
|
||||
#endif
|
||||
// Perform elimination!
|
||||
const ResultTree eliminationResults(productFactor, eliminate);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(HybridEliminate);
|
||||
#endif
|
||||
|
||||
// If there are no more continuous parents we create a DiscreteFactor with the
|
||||
// error for each discrete choice. Otherwise, create a HybridGaussianFactor
|
||||
|
|
|
@ -104,7 +104,13 @@ void HybridGaussianISAM::updateInternal(
|
|||
elimination_ordering, function, std::cref(index));
|
||||
|
||||
if (maxNrLeaves) {
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(HybridBayesTreePrune);
|
||||
#endif
|
||||
bayesTree->prune(*maxNrLeaves);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(HybridBayesTreePrune);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Re-add into Bayes tree data structures
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearFactor.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearISAM.h>
|
||||
|
@ -65,7 +66,14 @@ void HybridNonlinearISAM::reorderRelinearize() {
|
|||
// Obtain the new linearization point
|
||||
const Values newLinPoint = estimate();
|
||||
|
||||
auto discreteProbs = *(isam_.roots().at(0)->conditional()->asDiscrete());
|
||||
DiscreteConditional::shared_ptr discreteProbabilities;
|
||||
|
||||
auto discreteRoot = isam_.roots().at(0)->conditional();
|
||||
if (discreteRoot->asDiscrete<TableDistribution>()) {
|
||||
discreteProbabilities = discreteRoot->asDiscrete<TableDistribution>();
|
||||
} else {
|
||||
discreteProbabilities = discreteRoot->asDiscrete();
|
||||
}
|
||||
|
||||
isam_.clear();
|
||||
|
||||
|
@ -73,7 +81,7 @@ void HybridNonlinearISAM::reorderRelinearize() {
|
|||
HybridNonlinearFactorGraph pruned_factors;
|
||||
for (auto&& factor : factors_) {
|
||||
if (auto nf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
|
||||
pruned_factors.push_back(nf->prune(discreteProbs));
|
||||
pruned_factors.push_back(nf->prune(*discreteProbabilities));
|
||||
} else {
|
||||
pruned_factors.push_back(factor);
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
|
@ -79,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
|||
double midway = mu1 - mu0;
|
||||
auto eliminationResult =
|
||||
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
|
||||
auto pMid = *eliminationResult->at(0)->asDiscrete();
|
||||
EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid));
|
||||
auto pMid = eliminationResult->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(TableDistribution(m, "60 40"), *pMid));
|
||||
|
||||
// Everywhere else, the result should be a sigmoid.
|
||||
for (const double shift : {-4, -2, 0, 2, 4}) {
|
||||
|
@ -90,7 +91,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
|||
// Workflow 1: convert HBN to HFG and solve
|
||||
auto eliminationResult1 =
|
||||
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
||||
auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
|
||||
auto posterior1 =
|
||||
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||
|
||||
// Workflow 2: directly specify HFG and solve
|
||||
|
@ -99,7 +101,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
|||
m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)});
|
||||
hfg1.push_back(mixing);
|
||||
auto eliminationResult2 = hfg1.eliminateSequential();
|
||||
auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
|
||||
auto posterior2 =
|
||||
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||
}
|
||||
}
|
||||
|
@ -133,13 +136,13 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
// Eliminate the graph!
|
||||
auto eliminationResultMax = gfg.eliminateSequential();
|
||||
|
||||
// Equality of posteriors asserts that the elimination is correct (same ratios
|
||||
// for all modes)
|
||||
// Equality of posteriors asserts that the elimination is correct
|
||||
// (same ratios for all modes)
|
||||
EXPECT(assert_equal(expectedDiscretePosterior,
|
||||
eliminationResultMax->discretePosterior(vv)));
|
||||
|
||||
auto pMax = *eliminationResultMax->at(0)->asDiscrete();
|
||||
EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));
|
||||
auto pMax = *eliminationResultMax->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(TableDistribution(m, "42 58"), pMax, 1e-4));
|
||||
|
||||
// Everywhere else, the result should be a bell curve like function.
|
||||
for (const double shift : {-4, -2, 0, 2, 4}) {
|
||||
|
@ -149,7 +152,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
// Workflow 1: convert HBN to HFG and solve
|
||||
auto eliminationResult1 =
|
||||
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
||||
auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
|
||||
auto posterior1 =
|
||||
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||
|
||||
// Workflow 2: directly specify HFG and solve
|
||||
|
@ -158,10 +162,12 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)});
|
||||
hfg.push_back(mixing);
|
||||
auto eliminationResult2 = hfg.eliminateSequential();
|
||||
auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
|
||||
auto posterior2 =
|
||||
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
|
@ -454,7 +455,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
}
|
||||
|
||||
size_t maxNrLeaves = 3;
|
||||
auto prunedDecisionTree = joint.prune(maxNrLeaves);
|
||||
DiscreteConditional prunedDecisionTree(joint);
|
||||
prunedDecisionTree.prune(maxNrLeaves);
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/geometry/Pose2.h>
|
||||
#include <gtsam/geometry/Pose3.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
|
@ -464,14 +465,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
|
|||
|
||||
// Create expected discrete conditional on m0.
|
||||
DiscreteKey m(M(0), 2);
|
||||
DiscreteConditional expected(m % "0.51341712/1"); // regression
|
||||
TableDistribution expected(m, "0.51341712 1"); // regression
|
||||
|
||||
// Eliminate into BN using one ordering
|
||||
const Ordering ordering1{X(0), X(1), M(0)};
|
||||
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
|
||||
|
||||
// Check that the discrete conditional matches the expected.
|
||||
auto dc1 = bn1->back()->asDiscrete();
|
||||
auto dc1 = bn1->back()->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(expected, *dc1, 1e-9));
|
||||
|
||||
// Eliminate into BN using a different ordering
|
||||
|
@ -479,7 +480,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
|
|||
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
|
||||
|
||||
// Check that the discrete conditional matches the expected.
|
||||
auto dc2 = bn2->back()->asDiscrete();
|
||||
auto dc2 = bn2->back()->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(expected, *dc2, 1e-9));
|
||||
}
|
||||
|
||||
|
|
|
@ -261,7 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
potentials[i] = 1;
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
// Prune the HybridGaussianConditional
|
||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
// Check that the pruned HybridGaussianConditional has 1 conditional
|
||||
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
||||
}
|
||||
|
@ -271,7 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
0, 0, 0.5, 0};
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
||||
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
||||
|
@ -286,7 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
0, 0, 0.5, 0};
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
||||
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
@ -114,10 +115,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
|
|||
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
|
||||
|
||||
// Check that factor is discrete and correct
|
||||
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
||||
auto factor = std::dynamic_pointer_cast<TableFactor>(result.second);
|
||||
CHECK(factor);
|
||||
// regression test
|
||||
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
|
||||
EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -329,7 +330,7 @@ TEST(HybridBayesNet, Switching) {
|
|||
|
||||
// Check the remaining factor for x1
|
||||
CHECK(factor_x1);
|
||||
auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
|
||||
auto phi_x1 = std::dynamic_pointer_cast<TableFactor>(factor_x1);
|
||||
CHECK(phi_x1);
|
||||
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
|
||||
// We can't really check the error of the decision tree factor phi_x1, because
|
||||
|
@ -650,7 +651,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
|||
mode, std::vector{conditional0, conditional1});
|
||||
|
||||
// Add prior on mode.
|
||||
expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "74/26");
|
||||
expectedBayesNet.emplace_shared<TableDistribution>(mode, "74 26");
|
||||
|
||||
// Test elimination
|
||||
const auto posterior = fg.eliminateSequential();
|
||||
|
@ -700,11 +701,11 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
|
|||
m1, std::vector{conditional0, conditional1});
|
||||
|
||||
// Add prior on m1.
|
||||
expectedBayesNet.emplace_shared<DiscreteConditional>(m1, "1/1");
|
||||
expectedBayesNet.emplace_shared<TableDistribution>(m1, "0.188638 0.811362");
|
||||
|
||||
// Test elimination
|
||||
const auto posterior = fg.eliminateSequential();
|
||||
// EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||
|
||||
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||
|
||||
|
@ -736,7 +737,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
|||
mode, std::vector{conditional0, conditional1});
|
||||
|
||||
// Add prior on mode.
|
||||
expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "23/77");
|
||||
// Since this is the only discrete conditional, it is added as a
|
||||
// TableDistribution.
|
||||
expectedBayesNet.emplace_shared<TableDistribution>(mode, "23 77");
|
||||
|
||||
// Test elimination
|
||||
const auto posterior = fg.eliminateSequential();
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/geometry/Pose2.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||
|
@ -141,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) {
|
|||
expectedRemainingGraph->eliminateMultifrontal(discreteOrdering);
|
||||
|
||||
// Test the probability values with regression tests.
|
||||
auto discrete = isam[M(1)]->conditional()->asDiscrete();
|
||||
auto discrete = isam[M(1)]->conditional()->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
|
||||
EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
|
||||
EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));
|
||||
|
@ -221,16 +222,12 @@ TEST(HybridGaussianISAM, ApproxInference) {
|
|||
1 1 1 Leaf 0.5
|
||||
*/
|
||||
|
||||
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteConditional>(
|
||||
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
|
||||
incrementalHybrid[M(0)]->conditional()->inner());
|
||||
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
|
||||
|
||||
// Get the number of elements which are greater than 0.
|
||||
auto count = [](const double &value, int count) {
|
||||
return value > 0 ? count + 1 : count;
|
||||
};
|
||||
// Check that the number of leaves after pruning is 5.
|
||||
EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0));
|
||||
EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues());
|
||||
|
||||
// Check that the hybrid nodes of the bayes net match those of the pre-pruning
|
||||
// bayes net, at the same positions.
|
||||
|
@ -477,7 +474,9 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
|
||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||
DiscreteFactorGraph discreteGraph;
|
||||
discreteGraph.push_back(discreteTree);
|
||||
// discreteTree is a TableDistribution, so we convert to
|
||||
// DecisionTreeFactor for the DiscreteFactorGraph
|
||||
discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
|
||||
DiscreteValues optimal_assignment = discreteGraph.optimize();
|
||||
|
||||
DiscreteValues expected_assignment;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
|
@ -143,8 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
|
|||
// Since no measurement on x1, we hedge our bets
|
||||
// Importance sampling run with 100k samples gives 50.051/49.949
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteConditional expected(m1, "50/50");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete())));
|
||||
TableDistribution expected(m1, "50 50");
|
||||
EXPECT(
|
||||
assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>())));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -160,8 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
|
|||
// Since we have a measurement on x1, we get a definite result
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteConditional expected(m1, "44.3854/55.6146");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
|
||||
TableDistribution expected(m1, "44.3854 55.6146");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
|
||||
0.02));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -248,8 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteConditional expected(m1, "48.3158/51.6842");
|
||||
EXPECT(assert_equal(expected, *(eliminated->at(2)->asDiscrete()), 0.002));
|
||||
TableDistribution expected(m1, "48.3158 51.6842");
|
||||
EXPECT(assert_equal(
|
||||
expected, *(eliminated->at(2)->asDiscrete<TableDistribution>()), 0.02));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -263,8 +267,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteConditional expected(m1, "55.396/44.604");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
|
||||
TableDistribution expected(m1, "55.396 44.604");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
|
||||
0.02));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -340,8 +345,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteConditional expected(m1, "51.7762/48.2238");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
|
||||
TableDistribution expected(m1, "51.7762 48.2238");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
|
||||
0.02));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -355,8 +361,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteConditional expected(m1, "49.0762/50.9238");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.005));
|
||||
TableDistribution expected(m1, "49.0762 50.9238");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
|
||||
0.05));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -381,8 +388,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteConditional expected(m1, "8.91527/91.0847");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
|
||||
TableDistribution expected(m1, "8.91527 91.0847");
|
||||
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
|
||||
0.01));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -537,8 +545,8 @@ TEST(HybridGaussianFactorGraph, DifferentCovariances) {
|
|||
DiscreteValues dv0{{M(1), 0}};
|
||||
DiscreteValues dv1{{M(1), 1}};
|
||||
|
||||
DiscreteConditional expected_m1(m1, "0.5/0.5");
|
||||
DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
|
||||
TableDistribution expected_m1(m1, "0.5 0.5");
|
||||
TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete<TableDistribution>());
|
||||
|
||||
EXPECT(assert_equal(expected_m1, actual_m1));
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/geometry/Pose2.h>
|
||||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
@ -368,10 +369,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
|
|||
EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());
|
||||
|
||||
// This is now a discreteFactor
|
||||
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
|
||||
auto discreteFactor = dynamic_pointer_cast<TableFactor>(factorOnModes);
|
||||
CHECK(discreteFactor);
|
||||
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
|
||||
EXPECT(discreteFactor->root_->isLeaf() == false);
|
||||
}
|
||||
|
||||
/****************************************************************************
|
||||
|
@ -513,9 +513,10 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
|
|||
// P(m1)
|
||||
EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)});
|
||||
EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents());
|
||||
EXPECT(
|
||||
dynamic_pointer_cast<DiscreteConditional>(hybridBayesNet->at(4)->inner())
|
||||
->equals(*discreteBayesNet.at(1)));
|
||||
TableDistribution dtc =
|
||||
*hybridBayesNet->at(4)->asDiscrete<TableDistribution>();
|
||||
EXPECT(DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor())
|
||||
.equals(*discreteBayesNet.at(1)));
|
||||
}
|
||||
|
||||
/****************************************************************************
|
||||
|
@ -1062,8 +1063,8 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) {
|
|||
DiscreteValues dv0{{M(1), 0}};
|
||||
DiscreteValues dv1{{M(1), 1}};
|
||||
|
||||
DiscreteConditional expected_m1(m1, "0.5/0.5");
|
||||
DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
|
||||
TableDistribution expected_m1(m1, "0.5 0.5");
|
||||
TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete<TableDistribution>());
|
||||
|
||||
EXPECT(assert_equal(expected_m1, actual_m1));
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/geometry/Pose2.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearISAM.h>
|
||||
|
@ -265,16 +266,12 @@ TEST(HybridNonlinearISAM, ApproxInference) {
|
|||
1 1 1 Leaf 0.5
|
||||
*/
|
||||
|
||||
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteConditional>(
|
||||
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
|
||||
bayesTree[M(0)]->conditional()->inner());
|
||||
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
|
||||
|
||||
// Get the number of elements which are greater than 0.
|
||||
auto count = [](const double &value, int count) {
|
||||
return value > 0 ? count + 1 : count;
|
||||
};
|
||||
// Check that the number of leaves after pruning is 5.
|
||||
EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0));
|
||||
EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues());
|
||||
|
||||
// Check that the hybrid nodes of the bayes net match those of the pre-pruning
|
||||
// bayes net, at the same positions.
|
||||
|
@ -520,12 +517,13 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
|||
|
||||
// The final discrete graph should not be empty since we have eliminated
|
||||
// all continuous variables.
|
||||
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete();
|
||||
auto discreteTree =
|
||||
bayesTree[M(3)]->conditional()->asDiscrete<TableDistribution>();
|
||||
EXPECT_LONGS_EQUAL(3, discreteTree->size());
|
||||
|
||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||
DiscreteFactorGraph discreteGraph;
|
||||
discreteGraph.push_back(discreteTree);
|
||||
discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
|
||||
DiscreteValues optimal_assignment = discreteGraph.optimize();
|
||||
|
||||
DiscreteValues expected_assignment;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
|
@ -44,6 +45,7 @@ BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
|
|||
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
|
||||
BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
|
||||
BOOST_CLASS_EXPORT_GUID(TableDistribution, "gtsam_TableDistribution");
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
|
|
|
@ -13,14 +13,14 @@ Author: Fan Jiang, Varun Agrawal, Frank Dellaert
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from gtsam.symbol_shorthand import C, M, X, Z
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
import gtsam
|
||||
from gtsam import (DiscreteConditional, GaussianConditional,
|
||||
HybridBayesNet, HybridGaussianConditional,
|
||||
HybridGaussianFactor, HybridGaussianFactorGraph,
|
||||
HybridValues, JacobianFactor, noiseModel)
|
||||
from gtsam import (DiscreteConditional, GaussianConditional, HybridBayesNet,
|
||||
HybridGaussianConditional, HybridGaussianFactor,
|
||||
HybridGaussianFactorGraph, HybridValues, JacobianFactor,
|
||||
TableDistribution, noiseModel)
|
||||
from gtsam.symbol_shorthand import C, M, X, Z
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
DEBUG_MARGINALS = False
|
||||
|
||||
|
@ -51,7 +51,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
self.assertEqual(len(hybridCond.keys()), 2)
|
||||
|
||||
discrete_conditional = hbn.at(hbn.size() - 1).inner()
|
||||
self.assertIsInstance(discrete_conditional, DiscreteConditional)
|
||||
self.assertIsInstance(discrete_conditional, TableDistribution)
|
||||
|
||||
def test_optimize(self):
|
||||
"""Test construction of hybrid factor graph."""
|
||||
|
|
Loading…
Reference in New Issue