Merge pull request #1948 from borglab/hybrid-timing

release/4.3a0
Varun Agrawal 2025-01-08 12:45:26 -05:00 committed by GitHub
commit 169523ecc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 792 additions and 120 deletions

View File

@ -168,6 +168,12 @@ foreach(build_type "common" ${GTSAM_CMAKE_CONFIGURATION_TYPES})
append_config_if_not_empty(GTSAM_COMPILE_DEFINITIONS_PUBLIC ${build_type})
endforeach()
# Check if timing is enabled and add appropriate definition flag
if(GTSAM_ENABLE_TIMING AND(NOT ${CMAKE_BUILD_TYPE} EQUAL "Timing"))
message(STATUS "Enabling timing for non-timing build")
list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE "ENABLE_TIMING")
endif()
# Linker flags:
set(GTSAM_CMAKE_SHARED_LINKER_FLAGS_TIMING "${CMAKE_SHARED_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.")
set(GTSAM_CMAKE_MODULE_LINKER_FLAGS_TIMING "${CMAKE_MODULE_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.")

View File

@ -33,6 +33,8 @@ option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Qu
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON)
option(GTSAM_ENABLE_TIMING "Enable the timing tools (gttic/gttoc)" OFF)
option(GTSAM_HYBRID_TIMING "Enable the timing of hybrid factor graph machinery" OFF)
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)

View File

@ -91,6 +91,7 @@ print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory San
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree")
print_enabled_config(${GTSAM_ENABLE_TIMING} "Enable timing machinery")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")

View File

@ -31,7 +31,9 @@
namespace gtsam {
namespace internal {
using ChildOrder = FastMap<size_t, std::shared_ptr<TimingOutline>>;
// a static shared_ptr to TimingOutline with nullptr as the pointer
const static std::shared_ptr<TimingOutline> nullTimingOutline;
@ -91,7 +93,6 @@ void TimingOutline::print(const std::string& outline) const {
<< n_ << " times, " << wall() << " wall, " << secs() << " children, min: "
<< min() << " max: " << max() << ")\n";
// Order children
typedef FastMap<size_t, std::shared_ptr<TimingOutline> > ChildOrder;
ChildOrder childOrder;
for(const ChildMap::value_type& child: children_) {
childOrder[child.second->myOrder_] = child.second;
@ -106,6 +107,54 @@ void TimingOutline::print(const std::string& outline) const {
#endif
}
/* ************************************************************************* */
void TimingOutline::printCsvHeader(bool addLineBreak) const {
#ifdef GTSAM_USE_BOOST_FEATURES
// Order is (CPU time, number of times, wall time, time + children in seconds,
// min time, max time)
std::cout << label_ + " cpu time (s)" << "," << label_ + " #calls" << ","
<< label_ + " wall time(s)" << "," << label_ + " subtree time (s)"
<< "," << label_ + " min time (s)" << "," << label_ + "max time(s)"
<< ",";
// Order children
ChildOrder childOrder;
for (const ChildMap::value_type& child : children_) {
childOrder[child.second->myOrder_] = child.second;
}
// Print children
for (const ChildOrder::value_type& order_child : childOrder) {
order_child.second->printCsvHeader();
}
if (addLineBreak) {
std::cout << std::endl;
}
std::cout.flush();
#endif
}
/* ************************************************************************* */
void TimingOutline::printCsv(bool addLineBreak) const {
#ifdef GTSAM_USE_BOOST_FEATURES
// Order is (CPU time, number of times, wall time, time + children in seconds,
// min time, max time)
std::cout << self() << "," << n_ << "," << wall() << "," << secs() << ","
<< min() << "," << max() << ",";
// Order children
ChildOrder childOrder;
for (const ChildMap::value_type& child : children_) {
childOrder[child.second->myOrder_] = child.second;
}
// Print children
for (const ChildOrder::value_type& order_child : childOrder) {
order_child.second->printCsv(false);
}
if (addLineBreak) {
std::cout << std::endl;
}
std::cout.flush();
#endif
}
void TimingOutline::print2(const std::string& outline,
const double parentTotal) const {
#if GTSAM_USE_BOOST_FEATURES

View File

@ -199,6 +199,29 @@ namespace gtsam {
#endif
GTSAM_EXPORT void print(const std::string& outline = "") const;
GTSAM_EXPORT void print2(const std::string& outline = "", const double parentTotal = -1.0) const;
/**
* @brief Print the CSV header.
* Order is
* (CPU time, number of times, wall time, time + children in seconds, min
* time, max time)
*
* @param addLineBreak Flag indicating if a line break should be added at
* the end. Only used at the top-leve.
*/
GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const;
/**
* @brief Print the times recursively from parent to child in CSV format.
* For each timing node, the output is
* (CPU time, number of times, wall time, time + children in seconds, min
* time, max time)
*
* @param addLineBreak Flag indicating if a line break should be added at
* the end. Only used at the top-leve.
*/
GTSAM_EXPORT void printCsv(bool addLineBreak = false) const;
GTSAM_EXPORT const std::shared_ptr<TimingOutline>&
child(size_t child, const std::string& label, const std::weak_ptr<TimingOutline>& thisPtr);
GTSAM_EXPORT void tic();
@ -268,6 +291,14 @@ inline void tictoc_finishedIteration_() {
inline void tictoc_print_() {
::gtsam::internal::gTimingRoot->print(); }
// print timing in CSV format
inline void tictoc_printCsv_(bool displayHeader = false) {
if (displayHeader) {
::gtsam::internal::gTimingRoot->printCsvHeader(true);
}
::gtsam::internal::gTimingRoot->printCsv(true);
}
// print mean and standard deviation
inline void tictoc_print2_() {
::gtsam::internal::gTimingRoot->print2(); }

View File

@ -42,6 +42,9 @@
// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
#cmakedefine GTSAM_DT_MERGING
// Whether to enable timing in hybrid factor graph machinery
#cmakedefine01 GTSAM_HYBRID_TIMING
// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
#cmakedefine GTSAM_USE_TBB

View File

@ -57,6 +57,9 @@ namespace gtsam {
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
/// Constructor which accepts root pointer
AlgebraicDecisionTree(const typename Base::NodePtr root) : Base(root) {}
// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}

View File

@ -77,6 +77,13 @@ DiscreteConditional::DiscreteConditional(const Signature& signature)
/* ************************************************************************** */
DiscreteConditional DiscreteConditional::operator*(
const DiscreteConditional& other) const {
// If the root is a nullptr, we have a TableDistribution
// TODO(Varun) Revisit this hack after RSS2025 submission
if (!other.root_) {
DiscreteConditional dc(other.nrFrontals(), other.toDecisionTreeFactor());
return dc * (*this);
}
// Take union of frontal keys
std::set<Key> newFrontals;
for (auto&& key : this->frontals()) newFrontals.insert(key);
@ -479,6 +486,19 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->operator()(x.discrete());
}
/* ************************************************************************* */
DiscreteFactor::shared_ptr DiscreteConditional::max(
const Ordering& keys) const {
return BaseFactor::max(keys);
}
/* ************************************************************************* */
void DiscreteConditional::prune(size_t maxNrAssignments) {
// Get as DiscreteConditional so the probabilities are normalized
DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments));
this->root_ = pruned.root_;
}
/* ************************************************************************* */
double DiscreteConditional::negLogConstant() const { return 0.0; }

View File

@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
size_t sample(const DiscreteValues& parentsValues) const;
virtual size_t sample(const DiscreteValues& parentsValues) const;
/// Single parent version.
size_t sample(size_t parent_value) const;
@ -214,6 +214,15 @@ class GTSAM_EXPORT DiscreteConditional
*/
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
/**
* @brief Create new factor by maximizing over all
* values with the same separator.
*
* @param keys The keys to sum over.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// @}
/// @name Advanced Interface
/// @{
@ -267,6 +276,9 @@ class GTSAM_EXPORT DiscreteConditional
*/
double negLogConstant() const override;
/// Prune the conditional
virtual void prune(size_t maxNrAssignments);
/// @}
protected:

View File

@ -118,17 +118,11 @@ namespace gtsam {
// }
// }
/**
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DiscreteFactor::shared_ptr
*/
static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
/* ************************************************************************ */
DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const {
// PRODUCT: multiply all factors
gttic(product);
DiscreteFactor::shared_ptr product = factors.product();
DiscreteFactor::shared_ptr product = this->product();
gttoc(product);
// Max over all the potentials by pretending all keys are frontal:
@ -145,7 +139,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = factors.scaledProduct();
// max out frontals, this is the factor on the separator
gttic(max);
@ -223,7 +217,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = factors.scaledProduct();
// sum out frontals, this is the factor on the separator
gttic(sum);

View File

@ -150,6 +150,15 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** return product of all factors as a single factor */
DiscreteFactor::shared_ptr product() const;
/**
* @brief Return product of all `factors` as a single factor,
* which is scaled by the max value to prevent underflow
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DiscreteFactor::shared_ptr
*/
DiscreteFactor::shared_ptr scaledProduct() const;
/**
* Evaluates the factor graph given values, returns the joint probability of
* the factor graph given specific instantiation of values

View File

@ -0,0 +1,174 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file TableDistribution.cpp
* @date Dec 22, 2024
* @author Varun Agrawal
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/Ring.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridValues.h>
#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using namespace std;
using std::pair;
using std::stringstream;
using std::vector;
namespace gtsam {
/// Normalize sparse_table
static Eigen::SparseVector<double> normalizeSparseTable(
const Eigen::SparseVector<double>& sparse_table) {
return sparse_table / sparse_table.sum();
}
/* ************************************************************************** */
TableDistribution::TableDistribution(const TableFactor& f)
: BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)),
table_(f / (*std::dynamic_pointer_cast<TableFactor>(
f.sum(f.keys().size())))) {}
/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::vector<double>& potentials)
: BaseConditional(keys.size(), keys, ADT(nullptr)),
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}
/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::string& potentials)
: BaseConditional(keys.size(), keys, ADT(nullptr)),
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}
/* ************************************************************************** */
void TableDistribution::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << " P( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
cout << "):\n";
table_.print("", formatter);
cout << endl;
}
/* ************************************************************************** */
bool TableDistribution::equals(const DiscreteFactor& other, double tol) const {
auto dtc = dynamic_cast<const TableDistribution*>(&other);
if (!dtc) {
return false;
} else {
const DiscreteConditional& f(
static_cast<const DiscreteConditional&>(other));
return table_.equals(dtc->table_, tol) &&
DiscreteConditional::BaseConditional::equals(f, tol);
}
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const {
return table_.sum(nrFrontals);
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const {
return table_.sum(keys);
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const {
return table_.max(nrFrontals);
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const {
return table_.max(keys);
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::operator/(
const DiscreteFactor::shared_ptr& f) const {
return table_ / f;
}
/* ************************************************************************ */
DiscreteValues TableDistribution::argmax() const {
uint64_t maxIdx = 0;
double maxValue = 0.0;
Eigen::SparseVector<double> sparseTable = table_.sparseTable();
for (SparseIt it(sparseTable); it; ++it) {
if (it.value() > maxValue) {
maxIdx = it.index();
maxValue = it.value();
}
}
return table_.findAssignments(maxIdx);
}
/* ****************************************************************************/
void TableDistribution::prune(size_t maxNrAssignments) {
table_ = table_.prune(maxNrAssignments);
}
/* ****************************************************************************/
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator
DiscreteKeys parentsKeys;
for (auto&& [key, _] : parentsValues) {
parentsKeys.push_back({key, table_.cardinality(key)});
}
// Get the correct conditional distribution: P(F|S=parentsValues)
TableFactor pFS = table_.choose(parentsValues, parentsKeys);
// TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) {
throw std::invalid_argument(
"TableDistribution::sample can only be called on single variable "
"conditionals");
}
Key key = firstFrontalKey();
size_t nj = cardinality(key);
vector<double> p(nj);
DiscreteValues frontals;
for (size_t value = 0; value < nj; value++) {
frontals[key] = value;
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
if (p[value] == 1.0) {
return value; // shortcut exit
}
}
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng);
}
} // namespace gtsam

View File

@ -0,0 +1,177 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file TableDistribution.h
* @date Dec 22, 2024
* @author Varun Agrawal
*/
#pragma once
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/Conditional-inst.h>
#include <memory>
#include <string>
#include <vector>
namespace gtsam {
/**
* Distribution which uses a SparseVector as the internal
* representation, similar to the TableFactor.
*
* This is primarily used in the case when we have a clique in the BayesTree
* which consists of all the discrete variables, e.g. in hybrid elimination.
*
* @ingroup discrete
*/
class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
private:
TableFactor table_;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
public:
// typedefs needed to play nice with gtsam
typedef TableDistribution This; ///< Typedef to this class
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef DiscreteConditional
BaseConditional; ///< Typedef to our conditional base class
using Values = DiscreteValues; ///< backwards compatibility
/// @name Standard Constructors
/// @{
/// Default constructor needed for serialization.
TableDistribution() {}
/// Construct from TableFactor.
TableDistribution(const TableFactor& f);
/**
* Construct from DiscreteKeys and std::vector.
*/
TableDistribution(const DiscreteKeys& keys,
const std::vector<double>& potentials);
/**
* Construct from single DiscreteKey and std::vector.
*/
TableDistribution(const DiscreteKey& key,
const std::vector<double>& potentials)
: TableDistribution(DiscreteKeys(key), potentials) {}
/**
* Construct from DiscreteKey and std::string.
*/
TableDistribution(const DiscreteKeys& keys, const std::string& potentials);
/**
* Construct from single DiscreteKey and std::string.
*/
TableDistribution(const DiscreteKey& key, const std::string& potentials)
: TableDistribution(DiscreteKeys(key), potentials) {}
/// @}
/// @name Testable
/// @{
/// GTSAM-style print
void print(
const std::string& s = "Table Distribution: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// GTSAM-style equals
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
/// @}
/// @name Standard Interface
/// @{
/// Return the underlying TableFactor
TableFactor table() const { return table_; }
using BaseConditional::evaluate; // HybridValues version
/// Evaluate the conditional given the values.
virtual double evaluate(const Assignment<Key>& values) const override {
return table_.evaluate(values);
}
/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override;
/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/**
* @brief Return assignment that maximizes value.
*
* @return maximizing assignment for the variables.
*/
DiscreteValues argmax() const;
/**
* sample
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
virtual size_t sample(const DiscreteValues& parentsValues) const override;
/// @}
/// @name Advanced Interface
/// @{
/// Prune the conditional
virtual void prune(size_t maxNrAssignments) override;
/// Get a DecisionTreeFactor representation.
DecisionTreeFactor toDecisionTreeFactor() const override {
return table_.toDecisionTreeFactor();
}
/// Get the number of non-zero values.
uint64_t nrValues() const override { return table_.sparseTable().nonZeros(); }
/// @}
private:
#if GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(table_);
}
#endif
};
// TableDistribution
// traits
template <>
struct traits<TableDistribution> : public Testable<TableDistribution> {};
} // namespace gtsam

View File

@ -87,6 +87,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}
public:
/**
* Convert probability table given as doubles to SparseVector.
* Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
@ -98,7 +99,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
const std::string& table);
public:
// typedefs needed to play nice with gtsam
typedef TableFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DecisionTreeFactor toDecisionTreeFactor() const override;
/// Create a TableFactor that is a subset of this TableFactor
TableFactor choose(const DiscreteValues assignments,
TableFactor choose(const DiscreteValues parentAssignments,
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values

View File

@ -168,6 +168,43 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
std::vector<double> pmf() const;
};
#include <gtsam/discrete/TableFactor.h>
virtual class TableFactor : gtsam::DiscreteFactor {
TableFactor();
TableFactor(const gtsam::DiscreteKeys& keys,
const gtsam::TableFactor& potentials);
TableFactor(const gtsam::DiscreteKeys& keys, std::vector<double>& table);
TableFactor(const gtsam::DiscreteKeys& keys, string spec);
TableFactor(const gtsam::DiscreteKeys& keys,
const gtsam::DecisionTreeFactor& dtf);
TableFactor(const gtsam::DecisionTreeFactor& dtf);
void print(string s = "TableFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double error(const gtsam::DiscreteValues& values) const;
};
#include <gtsam/discrete/TableDistribution.h>
virtual class TableDistribution : gtsam::DiscreteConditional {
TableDistribution();
TableDistribution(const gtsam::TableFactor& f);
TableDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
TableDistribution(const gtsam::DiscreteKeys& keys, std::vector<double> spec);
TableDistribution(const gtsam::DiscreteKeys& keys, string spec);
TableDistribution(const gtsam::DiscreteKey& key, string spec);
void print(string s = "Table Distribution\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::TableFactor table() const;
double evaluate(const gtsam::DiscreteValues& values) const;
size_t nrValues() const;
};
#include <gtsam/discrete/DiscreteBayesNet.h>
class DiscreteBayesNet {
DiscreteBayesNet();

View File

@ -19,6 +19,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
@ -55,12 +56,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
joint = joint * (*conditional);
}
// Prune the joint. NOTE: again, possibly quite expensive.
const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);
// Create a the result starting with the pruned joint.
// Create the result starting with the pruned joint.
HybridBayesNet result;
result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);
result.emplace_shared<DiscreteConditional>(joint);
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
result.back()->asDiscrete()->prune(maxNrLeaves);
// Get pruned discrete probabilities so
// we can prune HybridGaussianConditionals.
DiscreteConditional pruned = *result.back()->asDiscrete();
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree
@ -126,7 +130,14 @@ HybridValues HybridBayesNet::optimize() const {
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_fg.push_back(conditional->asDiscrete());
if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
// The number of keys should be small so should not
// be expensive to convert to DiscreteConditional.
discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
dtc->toDecisionTreeFactor()));
} else {
discrete_fg.push_back(conditional->asDiscrete());
}
}
}

View File

@ -20,6 +20,7 @@
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
@ -41,6 +42,22 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}
/* ************************************************************************* */
DiscreteValues HybridBayesTree::discreteMaxProduct(
const DiscreteFactorGraph& dfg) const {
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
// Check type of product, and get as TableFactor for efficiency.
TableFactor p;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
p = *tf;
} else {
p = TableFactor(product->toDecisionTreeFactor());
}
DiscreteValues assignment = TableDistribution(p).argmax();
return assignment;
}
/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
DiscreteFactorGraph discrete_fg;
@ -52,8 +69,9 @@ HybridValues HybridBayesTree::optimize() const {
// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
discrete_fg.push_back(root_conditional->asDiscrete());
mpe = discrete_fg.optimize();
auto discrete = root_conditional->asDiscrete<TableDistribution>();
discrete_fg.push_back(discrete);
mpe = discreteMaxProduct(discrete_fg);
} else {
throw std::runtime_error(
"HybridBayesTree root is not discrete-only. Please check elimination "
@ -179,16 +197,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
auto prunedDiscreteProbs =
this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
discreteProbs->root_ = prunedDiscreteProbs.root_;
// Imperative pruning
prunedDiscreteProbs->prune(maxNrLeaves);
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteProbs;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
DiscreteConditional::shared_ptr prunedDiscreteProbs;
HybridPrunerData(const DiscreteConditional::shared_ptr& prunedDiscreteProbs,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteProbs(prunedDiscreteProbs) {}
@ -213,7 +232,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (!hybridGaussianCond->pruned()) {
// Imperative
clique->conditional() = std::make_shared<HybridConditional>(
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
hybridGaussianCond->prune(*parentData.prunedDiscreteProbs));
}
}
return parentData;

View File

@ -115,6 +115,10 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/// @}
private:
/// Helper method to compute the max product assignment
/// given a DiscreteFactorGraph
DiscreteValues discreteMaxProduct(const DiscreteFactorGraph& dfg) const;
#if GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;

View File

@ -166,12 +166,13 @@ class GTSAM_EXPORT HybridConditional
}
/**
* @brief Return conditional as a DiscreteConditional
* @brief Return conditional as a DiscreteConditional or specified type T.
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscrete() const {
return std::dynamic_pointer_cast<DiscreteConditional>(inner_);
template <typename T = DiscreteConditional>
typename T::shared_ptr asDiscrete() const {
return std::dynamic_pointer_cast<T>(inner_);
}
/// Get the type-erased pointer to the inner type

View File

@ -304,7 +304,7 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
/* *******************************************************************************/
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
const DecisionTreeFactor &discreteProbs) const {
const DiscreteConditional &discreteProbs) const {
// Find keys in discreteProbs.keys() but not in this->keys():
std::set<Key> mine(this->keys().begin(), this->keys().end());
std::set<Key> theirs(discreteProbs.keys().begin(),

View File

@ -23,6 +23,7 @@
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
@ -235,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional
* @return Shared pointer to possibly a pruned HybridGaussianConditional
*/
HybridGaussianConditional::shared_ptr prune(
const DecisionTreeFactor &discreteProbs) const;
const DiscreteConditional &discreteProbs) const;
/// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; }

View File

@ -20,12 +20,13 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h>
@ -241,29 +242,29 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */
/**
* @brief Take negative log-values, shift them so that the minimum value is 0,
* and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
* and then exponentiate to create a TableFactor (not normalized yet!).
*
* @param errors DecisionTree of (unnormalized) errors.
* @return DecisionTreeFactor::shared_ptr
* @return TableFactor::shared_ptr
*/
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
static TableFactor::shared_ptr DiscreteFactorFromErrors(
const DiscreteKeys &discreteKeys,
const AlgebraicDecisionTree<Key> &errors) {
double min_log = errors.min();
AlgebraicDecisionTree<Key> potentials(
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
return std::make_shared<TableFactor>(discreteKeys, potentials);
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
static DiscreteFactorGraph CollectDiscreteFactors(
const HybridGaussianFactorGraph &factors) {
DiscreteFactorGraph dfg;
for (auto &f : factors) {
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(df);
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute a discrete factor from the remaining error.
@ -282,16 +283,73 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto dc = hc->asDiscrete();
if (!dc) throwRuntimeError("discreteElimination", dc);
dfg.push_back(dc);
#if GTSAM_HYBRID_TIMING
gttic_(ConvertConditionalToTableFactor);
#endif
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
/// Get the underlying TableFactor
dfg.push_back(dtc->table());
} else {
// Convert DiscreteConditional to TableFactor
auto tdc = std::make_shared<TableFactor>(*dc);
dfg.push_back(tdc);
}
#if GTSAM_HYBRID_TIMING
gttoc_(ConvertConditionalToTableFactor);
#endif
} else {
throwRuntimeError("discreteElimination", f);
}
}
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
return dfg;
}
return {std::make_shared<HybridConditional>(result.first), result.second};
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// Check if separator is empty.
// This is the same as checking if the number of frontal variables
// is the same as the number of variables in the DiscreteFactorGraph.
// If the separator is empty, we have a clique of all the discrete variables
// so we can use the TableFactor for efficiency.
if (frontalKeys.size() == dfg.keys().size()) {
// Get product factor
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
// Check type of product, and get as TableFactor for efficiency.
TableFactor p;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
p = *tf;
} else {
p = TableFactor(product->toDecisionTreeFactor());
}
auto conditional = std::make_shared<TableDistribution>(p);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
return {std::make_shared<HybridConditional>(conditional), sum};
} else {
// Perform sum-product.
auto result = EliminateDiscrete(dfg, frontalKeys);
return {std::make_shared<HybridConditional>(result.first), result.second};
}
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif
}
/* ************************************************************************ */
@ -319,8 +377,19 @@ static std::shared_ptr<Factor> createDiscreteFactor(
}
};
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteBoundaryErrors);
#endif
AlgebraicDecisionTree<Key> errors(eliminationResults, calculateError);
return DiscreteFactorFromErrors(discreteSeparator, errors);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteBoundaryErrors);
gttic_(DiscreteBoundaryResult);
#endif
auto result = DiscreteFactorFromErrors(discreteSeparator, errors);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteBoundaryResult);
#endif
return result;
}
/* *******************************************************************************/
@ -360,12 +429,18 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// the discrete separator will be *all* the discrete keys.
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
#if GTSAM_HYBRID_TIMING
gttic_(HybridCollectProductFactor);
#endif
// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved. Just like any hybrid
// factor, every assignment also has a scalar error, in this case the sum of
// all errors in the graph. This error is assignment-specific and accounts for
// any difference in noise models used.
HybridGaussianProductFactor productFactor = collectProductFactor();
#if GTSAM_HYBRID_TIMING
gttoc_(HybridCollectProductFactor);
#endif
// Check if a factor is null
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
@ -393,8 +468,14 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
return {conditional, conditional->negLogConstant(), factor, scalar};
};
#if GTSAM_HYBRID_TIMING
gttic_(HybridEliminate);
#endif
// Perform elimination!
const ResultTree eliminationResults(productFactor, eliminate);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridEliminate);
#endif
// If there are no more continuous parents we create a DiscreteFactor with the
// error for each discrete choice. Otherwise, create a HybridGaussianFactor

View File

@ -104,7 +104,13 @@ void HybridGaussianISAM::updateInternal(
elimination_ordering, function, std::cref(index));
if (maxNrLeaves) {
#if GTSAM_HYBRID_TIMING
gttic_(HybridBayesTreePrune);
#endif
bayesTree->prune(*maxNrLeaves);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridBayesTreePrune);
#endif
}
// Re-add into Bayes tree data structures

View File

@ -15,6 +15,7 @@
* @author Varun Agrawal
*/
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/hybrid/HybridNonlinearISAM.h>
@ -65,7 +66,14 @@ void HybridNonlinearISAM::reorderRelinearize() {
// Obtain the new linearization point
const Values newLinPoint = estimate();
auto discreteProbs = *(isam_.roots().at(0)->conditional()->asDiscrete());
DiscreteConditional::shared_ptr discreteProbabilities;
auto discreteRoot = isam_.roots().at(0)->conditional();
if (discreteRoot->asDiscrete<TableDistribution>()) {
discreteProbabilities = discreteRoot->asDiscrete<TableDistribution>();
} else {
discreteProbabilities = discreteRoot->asDiscrete();
}
isam_.clear();
@ -73,7 +81,7 @@ void HybridNonlinearISAM::reorderRelinearize() {
HybridNonlinearFactorGraph pruned_factors;
for (auto&& factor : factors_) {
if (auto nf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
pruned_factors.push_back(nf->prune(discreteProbs));
pruned_factors.push_back(nf->prune(*discreteProbabilities));
} else {
pruned_factors.push_back(factor);
}

View File

@ -20,6 +20,7 @@
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
@ -79,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
double midway = mu1 - mu0;
auto eliminationResult =
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
auto pMid = *eliminationResult->at(0)->asDiscrete();
EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid));
auto pMid = eliminationResult->at(0)->asDiscrete<TableDistribution>();
EXPECT(assert_equal(TableDistribution(m, "60 40"), *pMid));
// Everywhere else, the result should be a sigmoid.
for (const double shift : {-4, -2, 0, 2, 4}) {
@ -90,7 +91,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
// Workflow 1: convert HBN to HFG and solve
auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
auto posterior1 =
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve
@ -99,7 +101,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)});
hfg1.push_back(mixing);
auto eliminationResult2 = hfg1.eliminateSequential();
auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
auto posterior2 =
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}
@ -133,13 +136,13 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
// Eliminate the graph!
auto eliminationResultMax = gfg.eliminateSequential();
// Equality of posteriors asserts that the elimination is correct (same ratios
// for all modes)
// Equality of posteriors asserts that the elimination is correct
// (same ratios for all modes)
EXPECT(assert_equal(expectedDiscretePosterior,
eliminationResultMax->discretePosterior(vv)));
auto pMax = *eliminationResultMax->at(0)->asDiscrete();
EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));
auto pMax = *eliminationResultMax->at(0)->asDiscrete<TableDistribution>();
EXPECT(assert_equal(TableDistribution(m, "42 58"), pMax, 1e-4));
// Everywhere else, the result should be a bell curve like function.
for (const double shift : {-4, -2, 0, 2, 4}) {
@ -149,7 +152,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
// Workflow 1: convert HBN to HFG and solve
auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
auto posterior1 =
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve
@ -158,10 +162,12 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)});
hfg.push_back(mixing);
auto eliminationResult2 = hfg.eliminateSequential();
auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
auto posterior2 =
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -20,6 +20,7 @@
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
@ -454,7 +455,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
}
size_t maxNrLeaves = 3;
auto prunedDecisionTree = joint.prune(maxNrLeaves);
DiscreteConditional prunedDecisionTree(joint);
prunedDecisionTree.prune(maxNrLeaves);
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,

View File

@ -16,6 +16,7 @@
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/geometry/Pose2.h>
#include <gtsam/geometry/Pose3.h>
#include <gtsam/hybrid/HybridBayesNet.h>
@ -464,14 +465,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
// Create expected discrete conditional on m0.
DiscreteKey m(M(0), 2);
DiscreteConditional expected(m % "0.51341712/1"); // regression
TableDistribution expected(m, "0.51341712 1"); // regression
// Eliminate into BN using one ordering
const Ordering ordering1{X(0), X(1), M(0)};
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
// Check that the discrete conditional matches the expected.
auto dc1 = bn1->back()->asDiscrete();
auto dc1 = bn1->back()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(expected, *dc1, 1e-9));
// Eliminate into BN using a different ordering
@ -479,7 +480,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
// Check that the discrete conditional matches the expected.
auto dc2 = bn2->back()->asDiscrete();
auto dc2 = bn2->back()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(expected, *dc2, 1e-9));
}

View File

@ -261,7 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
potentials[i] = 1;
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
// Prune the HybridGaussianConditional
const auto pruned = hgc.prune(decisionTreeFactor);
const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 1 conditional
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
}
@ -271,7 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune(decisionTreeFactor);
const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
@ -286,7 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune(decisionTreeFactor);
const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());

View File

@ -25,6 +25,7 @@
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridFactor.h>
@ -114,10 +115,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
// Check that factor is discrete and correct
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
auto factor = std::dynamic_pointer_cast<TableFactor>(result.second);
CHECK(factor);
// regression test
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5));
}
/* ************************************************************************* */
@ -329,7 +330,7 @@ TEST(HybridBayesNet, Switching) {
// Check the remaining factor for x1
CHECK(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<TableFactor>(factor_x1);
CHECK(phi_x1);
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
// We can't really check the error of the decision tree factor phi_x1, because
@ -650,7 +651,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
mode, std::vector{conditional0, conditional1});
// Add prior on mode.
expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "74/26");
expectedBayesNet.emplace_shared<TableDistribution>(mode, "74 26");
// Test elimination
const auto posterior = fg.eliminateSequential();
@ -700,11 +701,11 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
m1, std::vector{conditional0, conditional1});
// Add prior on m1.
expectedBayesNet.emplace_shared<DiscreteConditional>(m1, "1/1");
expectedBayesNet.emplace_shared<TableDistribution>(m1, "0.188638 0.811362");
// Test elimination
const auto posterior = fg.eliminateSequential();
// EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
EXPECT(ratioTest(bn, measurements, *posterior));
@ -736,7 +737,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
mode, std::vector{conditional0, conditional1});
// Add prior on mode.
expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "23/77");
// Since this is the only discrete conditional, it is added as a
// TableDistribution.
expectedBayesNet.emplace_shared<TableDistribution>(mode, "23 77");
// Test elimination
const auto posterior = fg.eliminateSequential();

View File

@ -19,6 +19,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/geometry/Pose2.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
@ -141,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) {
expectedRemainingGraph->eliminateMultifrontal(discreteOrdering);
// Test the probability values with regression tests.
auto discrete = isam[M(1)]->conditional()->asDiscrete();
auto discrete = isam[M(1)]->conditional()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));
@ -221,16 +222,12 @@ TEST(HybridGaussianISAM, ApproxInference) {
1 1 1 Leaf 0.5
*/
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteConditional>(
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
incrementalHybrid[M(0)]->conditional()->inner());
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
// Get the number of elements which are greater than 0.
auto count = [](const double &value, int count) {
return value > 0 ? count + 1 : count;
};
// Check that the number of leaves after pruning is 5.
EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0));
EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues());
// Check that the hybrid nodes of the bayes net match those of the pre-pruning
// bayes net, at the same positions.
@ -477,7 +474,9 @@ TEST(HybridGaussianISAM, NonTrivial) {
// Test if the optimal discrete mode assignment is (1, 1, 1).
DiscreteFactorGraph discreteGraph;
discreteGraph.push_back(discreteTree);
// discreteTree is a TableDistribution, so we convert to
// DecisionTreeFactor for the DiscreteFactorGraph
discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
DiscreteValues optimal_assignment = discreteGraph.optimize();
DiscreteValues expected_assignment;

View File

@ -22,6 +22,7 @@
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
@ -143,8 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
// Since no measurement on x1, we hedge our bets
// Importance sampling run with 100k samples gives 50.051/49.949
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "50/50");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete())));
TableDistribution expected(m1, "50 50");
EXPECT(
assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>())));
}
{
@ -160,8 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
// Since we have a measurement on x1, we get a definite result
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "44.3854/55.6146");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
TableDistribution expected(m1, "44.3854 55.6146");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
0.02));
}
}
@ -248,8 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "48.3158/51.6842");
EXPECT(assert_equal(expected, *(eliminated->at(2)->asDiscrete()), 0.002));
TableDistribution expected(m1, "48.3158 51.6842");
EXPECT(assert_equal(
expected, *(eliminated->at(2)->asDiscrete<TableDistribution>()), 0.02));
}
{
@ -263,8 +267,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "55.396/44.604");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
TableDistribution expected(m1, "55.396 44.604");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
0.02));
}
}
@ -340,8 +345,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "51.7762/48.2238");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
TableDistribution expected(m1, "51.7762 48.2238");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
0.02));
}
{
@ -355,8 +361,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "49.0762/50.9238");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.005));
TableDistribution expected(m1, "49.0762 50.9238");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
0.05));
}
}
@ -381,8 +388,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "8.91527/91.0847");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
TableDistribution expected(m1, "8.91527 91.0847");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
0.01));
}
/* ************************************************************************* */
@ -537,8 +545,8 @@ TEST(HybridGaussianFactorGraph, DifferentCovariances) {
DiscreteValues dv0{{M(1), 0}};
DiscreteValues dv1{{M(1), 1}};
DiscreteConditional expected_m1(m1, "0.5/0.5");
DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
TableDistribution expected_m1(m1, "0.5 0.5");
TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete<TableDistribution>());
EXPECT(assert_equal(expected_m1, actual_m1));
}

View File

@ -20,6 +20,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/geometry/Pose2.h>
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h>
@ -368,10 +369,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());
// This is now a discreteFactor
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
auto discreteFactor = dynamic_pointer_cast<TableFactor>(factorOnModes);
CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false);
}
/****************************************************************************
@ -513,9 +513,10 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
// P(m1)
EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)});
EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents());
EXPECT(
dynamic_pointer_cast<DiscreteConditional>(hybridBayesNet->at(4)->inner())
->equals(*discreteBayesNet.at(1)));
TableDistribution dtc =
*hybridBayesNet->at(4)->asDiscrete<TableDistribution>();
EXPECT(DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor())
.equals(*discreteBayesNet.at(1)));
}
/****************************************************************************
@ -1062,8 +1063,8 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) {
DiscreteValues dv0{{M(1), 0}};
DiscreteValues dv1{{M(1), 1}};
DiscreteConditional expected_m1(m1, "0.5/0.5");
DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
TableDistribution expected_m1(m1, "0.5 0.5");
TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete<TableDistribution>());
EXPECT(assert_equal(expected_m1, actual_m1));
}

View File

@ -19,6 +19,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/geometry/Pose2.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridNonlinearISAM.h>
@ -265,16 +266,12 @@ TEST(HybridNonlinearISAM, ApproxInference) {
1 1 1 Leaf 0.5
*/
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteConditional>(
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
bayesTree[M(0)]->conditional()->inner());
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
// Get the number of elements which are greater than 0.
auto count = [](const double &value, int count) {
return value > 0 ? count + 1 : count;
};
// Check that the number of leaves after pruning is 5.
EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0));
EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues());
// Check that the hybrid nodes of the bayes net match those of the pre-pruning
// bayes net, at the same positions.
@ -520,12 +517,13 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// The final discrete graph should not be empty since we have eliminated
// all continuous variables.
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete();
auto discreteTree =
bayesTree[M(3)]->conditional()->asDiscrete<TableDistribution>();
EXPECT_LONGS_EQUAL(3, discreteTree->size());
// Test if the optimal discrete mode assignment is (1, 1, 1).
DiscreteFactorGraph discreteGraph;
discreteGraph.push_back(discreteTree);
discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
DiscreteValues optimal_assignment = discreteGraph.optimize();
DiscreteValues expected_assignment;

View File

@ -18,6 +18,7 @@
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
@ -44,6 +45,7 @@ BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
BOOST_CLASS_EXPORT_GUID(TableDistribution, "gtsam_TableDistribution");
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>;

View File

@ -13,14 +13,14 @@ Author: Fan Jiang, Varun Agrawal, Frank Dellaert
import unittest
import numpy as np
from gtsam.symbol_shorthand import C, M, X, Z
from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import (DiscreteConditional, GaussianConditional,
HybridBayesNet, HybridGaussianConditional,
HybridGaussianFactor, HybridGaussianFactorGraph,
HybridValues, JacobianFactor, noiseModel)
from gtsam import (DiscreteConditional, GaussianConditional, HybridBayesNet,
HybridGaussianConditional, HybridGaussianFactor,
HybridGaussianFactorGraph, HybridValues, JacobianFactor,
TableDistribution, noiseModel)
from gtsam.symbol_shorthand import C, M, X, Z
from gtsam.utils.test_case import GtsamTestCase
DEBUG_MARGINALS = False
@ -51,7 +51,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertEqual(len(hybridCond.keys()), 2)
discrete_conditional = hbn.at(hbn.size() - 1).inner()
self.assertIsInstance(discrete_conditional, DiscreteConditional)
self.assertIsInstance(discrete_conditional, TableDistribution)
def test_optimize(self):
"""Test construction of hybrid factor graph."""