Merge pull request #1358 from borglab/hybrid/gaussian-mixture-factor

release/4.3a0
Frank Dellaert 2022-12-30 23:28:36 -05:00 committed by GitHub
commit d0821a57de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 364 additions and 234 deletions

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h>
@ -149,17 +150,19 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->likelihood(frontals);
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(frontals),
conditional->logNormalizationConstant()};
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}
/* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
s.insert(discreteKeys.begin(), discreteKeys.end());
return s;
}
@ -184,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
const DiscreteValues values(choices);
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
@ -254,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
}
/* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(discreteValues);
return conditional->error(continuousValues);
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous());
}
} // namespace gtsam

View File

@ -30,6 +30,7 @@
namespace gtsam {
class GaussianMixtureFactor;
class HybridValues;
/**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Constructors
/// @{
/// Defaut constructor, mainly for serialization.
/// Default constructor, mainly for serialization.
GaussianMixture() = default;
/**
@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Standard API
/// @{
/// @brief Return the conditional Gaussian for the given discrete assignment.
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;
@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
double error(const HybridValues &values) const override;
/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture
};
/// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
// traits
template <>

View File

@ -22,6 +22,8 @@
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
@ -29,8 +31,11 @@ namespace gtsam {
/* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}
const Mixture &factors)
: Base(continuousKeys, discreteKeys),
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
return FactorAndConstant{gf, 0.0};
}) {}
/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
@ -43,10 +48,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_,
[tol](const GaussianFactor::shared_ptr &f1,
const GaussianFactor::shared_ptr &f2) {
return f1->equals(*f2, tol);
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
const FactorAndConstant &f2) {
return f1.factor->equals(*(f2.factor), tol) &&
std::abs(f1.constant - f2.constant) < tol;
});
}
@ -60,7 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s,
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
[&](const FactorAndConstant &gf_z) -> std::string {
auto gf = gf_z.factor;
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
@ -75,8 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s,
}
/* *******************************************************************************/
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
return factors_;
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
return Mixture(factors_, [](const FactorAndConstant &factor_z) {
return factor_z.factor;
});
}
/* *******************************************************************************/
@ -95,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result;
result.push_back(factor);
result.push_back(factor_z.factor);
return result;
};
return {factors_, wrap};
@ -107,21 +115,18 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const GaussianFactor::shared_ptr &factor) {
return factor->error(continuousValues);
auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
return factor_z.error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}
/* *******************************************************************************/
double GaussianMixtureFactor::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto factor = factors_(discreteValues);
return factor->error(continuousValues);
double GaussianMixtureFactor::error(const HybridValues &values) const {
const FactorAndConstant factor_z = factors_(values.discrete());
return factor_z.error(values.continuous());
}
/* *******************************************************************************/
} // namespace gtsam

View File

@ -23,17 +23,15 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>
namespace gtsam {
class GaussianFactorGraph;
// Needed for wrapper.
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
class HybridValues;
class DiscreteValues;
class VectorValues;
/**
* @brief Implementation of a discrete conditional mixture factor.
@ -53,9 +51,29 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using shared_ptr = boost::shared_ptr<This>;
using Sum = DecisionTree<Key, GaussianFactorGraph>;
using sharedFactor = boost::shared_ptr<GaussianFactor>;
/// typedef for Decision Tree of Gaussian Factors
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
/// Gaussian factor and log of normalizing constant.
struct FactorAndConstant {
sharedFactor factor;
double constant;
// Return error with constant correction.
double error(const VectorValues &values) const {
// Note minus sign: constant is log of normalization constant for probabilities.
// Error is the negative log-likelihood, hence we subtract the constant here.
return factor->error(values) - constant;
}
// Check pointer equality.
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}
};
/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>;
private:
/// Decision tree of Gaussian factors indexed by discrete keys.
@ -82,12 +100,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and
* their cardinalities.
* @param factors The decision tree of Gaussian Factors stored as the mixture
* @param factors The decision tree of Gaussian factors stored as the mixture
* density.
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors);
const Mixture &factors);
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors_and_z)
: Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}
/**
* @brief Construct a new GaussianMixtureFactor object using a vector of
@ -99,9 +122,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors)
const std::vector<sharedFactor> &factors)
: GaussianMixtureFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {}
Mixture(discreteKeys, factors)) {}
/// @}
/// @name Testable
@ -113,9 +136,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard API
/// @{
/// Getter for the underlying Gaussian Factor Decision Tree.
const Factors &factors();
const Mixture factors() const;
/**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@ -137,21 +162,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @brief Compute the error (negative log-likelihood), correcting for the log-normalizing constant.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
double error(const HybridValues &values) const override;
/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);
return sum;
}
/// @}
};
// traits

View File

@ -1,5 +1,5 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
@ -12,6 +12,7 @@
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
* @author Frank Dellaert
* @date January 2022
*/
@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const {
}
/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = choose(discreteValues);
return gbn.error(continuousValues);
double HybridBayesNet::error(const HybridValues &values) const {
GaussianBayesNet gbn = choose(values.discrete());
return gbn.error(values.continuous());
}
/* ************************************************************************* */

View File

@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
double error(const HybridValues &values) const;
/**
* @brief Compute conditional error for each discrete assignment,

View File

@ -52,7 +52,7 @@ namespace gtsam {
* having diamond inheritances, and neutralized the need to change other
* components of GTSAM to make hybrid elimination work.
*
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
* A great reference to the type-erasure pattern is Eduardo Madrid's CppCon
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
*
* @ingroup hybrid
@ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional
*/
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
/**
* @brief Return HybridConditional as a GaussianMixture
* @return nullptr if not a mixture
* @return GaussianMixture::shared_ptr otherwise
*/
GaussianMixture::shared_ptr asMixture() {
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
}
/**
* @brief Return HybridConditional as a GaussianConditional
* @return nullptr if not a GaussianConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() {
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
}
/**
* @brief Return conditional as a DiscreteConditional
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscrete() {
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
}
/// @}
/// @name Testable
/// @{
@ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
/// @}
/// @name Standard Interface
/// @{
/**
* @brief Return HybridConditional as a GaussianMixture
* @return nullptr if not a mixture
* @return GaussianMixture::shared_ptr otherwise
*/
GaussianMixture::shared_ptr asMixture() const {
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
}
/**
* @brief Return HybridConditional as a GaussianConditional
* @return nullptr if not a GaussianConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() const {
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
}
/**
* @brief Return conditional as a DiscreteConditional
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscrete() const {
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
}
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; }
/// Return the error of the underlying conditional.
/// Currently only implemented for Gaussian mixture.
double error(const HybridValues& values) const override {
if (auto gm = asMixture()) {
return gm->error(values);
} else {
throw std::runtime_error(
"HybridConditional::error: only implemented for Gaussian mixture");
}
}
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;

View File

@ -17,6 +17,7 @@
*/
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <boost/make_shared.hpp>
@ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s,
inner_->print("\n", formatter);
};
/* ************************************************************************ */
double HybridDiscreteFactor::error(const HybridValues &values) const {
return -log((*inner_)(values.discrete()));
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -24,10 +24,12 @@
namespace gtsam {
class HybridValues;
/**
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows
* us to hide the implementation of DiscreteFactor and thus avoid diamond
* inheritance.
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which
* allows us to hide the implementation of DiscreteFactor and thus avoid
* diamond inheritance.
*
* @ingroup hybrid
*/
@ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
/// Return pointer to the internal discrete factor
DiscreteFactor::shared_ptr inner() const { return inner_; }
/// Return the error of the underlying Discrete Factor.
double error(const HybridValues &values) const override;
/// @}
};
// traits

View File

@ -26,6 +26,8 @@
#include <string>
namespace gtsam {
class HybridValues;
KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
@ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// @name Standard Interface
/// @{
/**
* @brief Compute the error of this hybrid factor given the continuous
* values and a discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
virtual double error(const HybridValues &values) const = 0;
/// True if this is a factor of discrete variables only.
bool isDiscrete() const { return isDiscrete_; }

View File

@ -16,6 +16,7 @@
*/
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/HessianFactor.h>
#include <gtsam/linear/JacobianFactor.h>
@ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s,
inner_->print("\n", formatter);
};
/* ************************************************************************ */
double HybridGaussianFactor::error(const HybridValues &values) const {
return inner_->error(values.continuous());
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -25,6 +25,7 @@ namespace gtsam {
// Forward declarations
class JacobianFactor;
class HessianFactor;
class HybridValues;
/**
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
@ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
/// Return pointer to the internal Gaussian factor
GaussianFactor::shared_ptr inner() const { return inner_; }
/// Return the error of the underlying Gaussian factor.
double error(const HybridValues &values) const override;
/// @}
};
// traits

View File

@ -55,13 +55,14 @@
namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
using Y = GaussianFactorGraph;
// If the decision tree is not intiialized, then intialize it.
// If the decision tree is not initialized, then initialize it.
if (sum.empty()) {
GaussianFactorGraph result;
result.push_back(factor);
@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals(
for (auto &f : factors) {
if (f->isHybrid()) {
if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = cgmf->add(sum);
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = gm->add(sum);
}
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum);
@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
const KeySet &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree,
// only possiblity is continuous conditioned on discrete.
// only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end());
@ -204,16 +206,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
using EliminationPair = GaussianFactorGraph::EliminationResult;
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>;
KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
// This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianFactorGraph &graph)
-> GaussianFactorGraph::EliminationResult {
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
if (graph.empty()) {
return {nullptr, nullptr};
return {nullptr, {nullptr, 0.0}};
}
#ifdef HYBRID_TIMING
@ -222,18 +224,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>>
result = EliminatePreferCholesky(graph, frontalKeys);
conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
// Initialize the keysOfEliminated to be the keys of the
// eliminated GaussianConditional
keysOfEliminated = result.first->keys();
keysOfSeparator = result.second->keys();
keysOfEliminated = conditional_factor.first->keys();
keysOfSeparator = conditional_factor.second->keys();
#ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate);
#endif
return result;
return {conditional_factor.first, {conditional_factor.second, 0.0}};
};
// Perform elimination!
@ -246,8 +248,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Separate out decision tree into conditionals and remaining factors.
auto pair = unzip(eliminationResults);
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
const auto &separatorFactors = pair.second;
// Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>(
@ -257,13 +258,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// DiscreteFactor, with the error for each discrete choice.
if (keysOfSeparator.empty()) {
VectorValues empty_values;
auto factorProb = [&](const GaussianFactor::shared_ptr &factor) {
auto factorProb =
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
GaussianFactor::shared_ptr factor = factor_z.factor;
if (!factor) {
return 0.0; // If nullptr, return 0.0 probability
} else {
// This is the probability q(μ) at the MLE point.
double error =
0.5 * std::abs(factor->augmentedInformation().determinant());
0.5 * std::abs(factor->augmentedInformation().determinant()) +
factor_z.constant;
return std::exp(-error);
}
};
@ -452,6 +456,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;
if (factors_.at(idx)->isHybrid()) {
@ -491,38 +496,17 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
}
/* ************************************************************************ */
double HybridGaussianFactorGraph::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0;
for (size_t idx = 0; idx < size(); idx++) {
auto factor = factors_.at(idx);
if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(continuousValues, discreteValues);
}
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(continuousValues, discreteValues);
}
} else if (factor->isContinuous()) {
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(continuousValues);
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(continuousValues);
}
}
for (auto &factor : factors_) {
error += factor->error(values);
}
return error;
}
/* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = this->error(continuousValues, discreteValues);
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
double error = this->error(values);
// NOTE: The 0.5 term is handled by each factor
return std::exp(-error);
}

View File

@ -12,7 +12,7 @@
/**
* @file HybridGaussianFactorGraph.h
* @brief Linearized Hybrid factor graph that uses type erasure
* @author Fan Jiang, Varun Agrawal
* @author Fan Jiang, Varun Agrawal, Frank Dellaert
* @date Mar 11, 2022
*/
@ -38,6 +38,7 @@ class HybridBayesTree;
class HybridJunctionTree;
class DecisionTreeFactor;
class JacobianFactor;
class HybridValues;
/**
* @brief Main elimination function for HybridGaussianFactorGraph.
@ -186,14 +187,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Compute error given a continuous vector values
* and a discrete assignment.
*
* @param continuousValues The continuous VectorValues
* for computing the error.
* @param discreteValues The specific discrete assignment
* whose error we wish to compute.
* @return double
*/
double error(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;
double error(const HybridValues& values) const;
/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
@ -210,13 +206,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment.
*
* @param continuousValues The vector values for which to compute the
* posterior probability.
* @param discreteValues The specific assignment to use for the computation.
* @return double
*/
double probPrime(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;
double probPrime(const HybridValues& values) const;
/**
* @brief Return a Colamd constrained ordering where the discrete keys are

View File

@ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor {
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
NonlinearFactor::shared_ptr inner() const { return inner_; }
/// Error for HybridValues is not provided for nonlinear factor.
double error(const HybridValues &values) const override {
throw std::runtime_error(
"HybridNonlinearFactor::error(HybridValues) not implemented.");
}
/// Linearize to a HybridGaussianFactor at the linearization point `c`.
boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const {
return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c));
}
/// @}
};
} // namespace gtsam

View File

@ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor {
factor, continuousValues);
}
/// Error for HybridValues is not provided for nonlinear hybrid factor.
double error(const HybridValues &values) const override {
throw std::runtime_error(
"MixtureFactor::error(HybridValues) not implemented.");
}
size_t dim() const {
// TODO(Varun)
throw std::runtime_error("MixtureFactor::dim not implemented.");

View File

@ -183,10 +183,8 @@ class HybridGaussianFactorGraph {
bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const;
// evaluation
double error(const gtsam::VectorValues& continuousValues,
const gtsam::DiscreteValues& discreteValues) const;
double probPrime(const gtsam::VectorValues& continuousValues,
const gtsam::DiscreteValues& discreteValues) const;
double error(const gtsam::HybridValues& values) const;
double probPrime(const gtsam::HybridValues& values) const;
gtsam::HybridBayesNet* eliminateSequential();
gtsam::HybridBayesNet* eliminateSequential(

View File

@ -128,9 +128,9 @@ TEST(GaussianMixture, Error) {
// Regression for non-tree version.
DiscreteValues assignment;
assignment[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8);
EXPECT_DOUBLES_EQUAL(0.5, mixture.error({values, assignment}), 1e-8);
assignment[M(1)] = 1;
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment),
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error({values, assignment}),
1e-8);
}
@ -179,7 +179,9 @@ TEST(GaussianMixture, Likelihood) {
const GaussianMixtureFactor::Factors factors(
gm.conditionals(),
[measurements](const GaussianConditional::shared_ptr& conditional) {
return conditional->likelihood(measurements);
return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(measurements),
conditional->logNormalizationConstant()};
});
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
EXPECT(assert_equal(*factor, expected));

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianFactorGraph.h>
@ -188,7 +189,8 @@ TEST(GaussianMixtureFactor, Error) {
DiscreteValues discreteValues;
discreteValues[m1.first] = 1;
EXPECT_DOUBLES_EQUAL(
4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9);
4.0, mixtureFactor.error({continuousValues, discreteValues}),
1e-9);
}
/* ************************************************************************* */

View File

@ -243,8 +243,8 @@ TEST(HybridBayesNet, Error) {
double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
if (hybridBayesNet->at(idx)->isHybrid()) {
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(),
discrete_values);
double error = hybridBayesNet->atMixture(idx)->error(
{delta.continuous(), discrete_values});
total_error += error;
} else if (hybridBayesNet->at(idx)->isContinuous()) {
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
@ -253,7 +253,7 @@ TEST(HybridBayesNet, Error) {
}
EXPECT_DOUBLES_EQUAL(
total_error, hybridBayesNet->error(delta.continuous(), discrete_values),
total_error, hybridBayesNet->error({delta.continuous(), discrete_values}),
1e-9);
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);

View File

@ -273,7 +273,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
continue;
}
double error = graph.error(delta, assignment);
double error = graph.error({delta, assignment});
probPrimes.push_back(exp(-error));
}
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
@ -487,8 +487,8 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
const HybridValues& sample) -> double {
const DiscreteValues assignment = sample.discrete();
// Compute in log form for numerical stability
double log_ratio = bayesNet->error(sample.continuous(), assignment) -
factorGraph->error(sample.continuous(), assignment);
double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
factorGraph->error({sample.continuous(), assignment});
double ratio = exp(-log_ratio);
return ratio;
};

View File

@ -575,18 +575,14 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
double error = graph.error(delta.continuous(), delta.discrete());
double expected_error = 0.490243199;
// regression
EXPECT(assert_equal(expected_error, error, 1e-9));
double probs = exp(-error);
double expected_probs = graph.probPrime(delta.continuous(), delta.discrete());
const HybridValues delta = hybridBayesNet->optimize();
const double error = graph.error(delta);
// regression
EXPECT(assert_equal(expected_probs, probs, 1e-7));
EXPECT(assert_equal(1.58886, error, 1e-5));
// Real test:
EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7));
}
/* ****************************************************************************/

View File

@ -168,26 +168,30 @@ namespace gtsam {
/* ************************************************************************* */
double GaussianConditional::logDeterminant() const {
double logDet;
if (this->get_model()) {
Vector diag = this->R().diagonal();
this->get_model()->whitenInPlace(diag);
logDet = diag.unaryExpr([](double x) { return log(x); }).sum();
if (get_model()) {
Vector diag = R().diagonal();
get_model()->whitenInPlace(diag);
return diag.unaryExpr([](double x) { return log(x); }).sum();
} else {
logDet =
this->R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
return R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
}
return logDet;
}
/* ************************************************************************* */
// density = exp(-error(x)) / sqrt((2*pi)^n*det(Sigma))
// log = -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
double GaussianConditional::logDensity(const VectorValues& x) const {
// normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
// log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
double GaussianConditional::logNormalizationConstant() const {
constexpr double log2pi = 1.8378770664093454835606594728112;
size_t n = d().size();
// log det(Sigma)) = - 2.0 * logDeterminant()
return - error(x) - 0.5 * n * log2pi + logDeterminant();
return - 0.5 * n * log2pi + logDeterminant();
}
/* ************************************************************************* */
// density = k exp(-error(x))
// log = log(k) -error(x) - 0.5 * n*log(2*pi)
double GaussianConditional::logDensity(const VectorValues& x) const {
return logNormalizationConstant() - error(x);
}
/* ************************************************************************* */

View File

@ -169,7 +169,7 @@ namespace gtsam {
*
* @return double
*/
double determinant() const { return exp(this->logDeterminant()); }
inline double determinant() const { return exp(logDeterminant()); }
/**
* @brief Compute the log determinant of the R matrix.
@ -184,6 +184,19 @@ namespace gtsam {
*/
double logDeterminant() const;
/**
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
* log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
*/
double logNormalizationConstant() const;
/**
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
*/
inline double normalizationConstant() const {
return exp(logNormalizationConstant());
}
/**
* Solves a conditional Gaussian and writes the solution into the entries of
* \c x for each frontal variable of the conditional. The parents are

View File

@ -6,7 +6,7 @@ All Rights Reserved
See LICENSE for the license information
Unit tests for Hybrid Factor Graphs.
Author: Fan Jiang
Author: Fan Jiang, Varun Agrawal, Frank Dellaert
"""
# pylint: disable=invalid-name, no-name-in-module, no-member
@ -18,13 +18,14 @@ from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
GaussianMixture, GaussianMixtureFactor,
GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues,
HybridGaussianFactorGraph, JacobianFactor, Ordering,
noiseModel)
class TestHybridGaussianFactorGraph(GtsamTestCase):
"""Unit tests for HybridGaussianFactorGraph."""
def test_create(self):
"""Test construction of hybrid factor graph."""
model = noiseModel.Unit.Create(3)
@ -81,13 +82,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertEqual(hv.atDiscrete(C(0)), 1)
@staticmethod
def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet:
def tiny(num_measurements: int = 1) -> HybridBayesNet:
"""
Create a tiny two variable hybrid model which represents
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
"""
# Create hybrid Bayes net.
bayesNet = gtsam.HybridBayesNet()
bayesNet = HybridBayesNet()
# Create mode key: 0 is low-noise, 1 is high-noise.
mode = (M(0), 2)
@ -113,35 +114,76 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
bayesNet.addGaussian(prior_on_x0)
# Add prior on mode.
bayesNet.emplaceDiscrete(mode, "1/1")
bayesNet.emplaceDiscrete(mode, "4/6")
return bayesNet
@staticmethod
def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues):
"""Create a factor graph from the Bayes net with sampled measurements.
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
and thus represents the same joint probability as the Bayes net.
"""
fg = HybridGaussianFactorGraph()
num_measurements = bayesNet.size() - 2
for i in range(num_measurements):
conditional = bayesNet.atMixture(i)
measurement = gtsam.VectorValues()
measurement.insert(Z(i), sample.at(Z(i)))
factor = conditional.likelihood(measurement)
fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(num_measurements))
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
return fg
@classmethod
def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000):
"""Do importance sampling to get an estimate of the discrete marginal P(mode)."""
# Use prior on x0, mode as proposal density.
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode)
# Allocate space for marginals.
marginals = np.zeros((2,))
# Do importance sampling.
num_measurements = bayesNet.size() - 2
for s in range(N):
proposed = prior.sample()
for i in range(num_measurements):
z_i = sample.at(Z(i))
proposed.insert(Z(i), z_i)
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
marginals[proposed.atDiscrete(M(0))] += weight
# print marginals:
marginals /= marginals.sum()
return marginals
def test_tiny(self):
"""Test a tiny two variable hybrid model."""
bayesNet = self.tiny()
sample = bayesNet.sample()
# print(sample)
# Create a factor graph from the Bayes net with sampled measurements.
fg = HybridGaussianFactorGraph()
conditional = bayesNet.atMixture(0)
measurement = gtsam.VectorValues()
measurement.insert(Z(0), sample.at(Z(0)))
factor = conditional.likelihood(measurement)
fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(1))
fg.push_back(bayesNet.atDiscrete(2))
# Estimate marginals using importance sampling.
marginals = self.estimate_marginals(bayesNet, sample)
# print(f"True mode: {sample.atDiscrete(M(0))}")
# print(f"P(mode=0; z0) = {marginals[0]}")
# print(f"P(mode=1; z0) = {marginals[1]}")
# Check that the estimate is close to the true value.
self.assertAlmostEqual(marginals[0], 0.4, delta=0.1)
self.assertAlmostEqual(marginals[1], 0.6, delta=0.1)
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
self.assertEqual(fg.size(), 3)
@staticmethod
def calculate_ratio(bayesNet, fg, sample):
def calculate_ratio(bayesNet: HybridBayesNet,
fg: HybridGaussianFactorGraph,
sample: HybridValues):
"""Calculate ratio between Bayes net probability and the factor graph."""
continuous = gtsam.VectorValues()
continuous.insert(X(0), sample.at(X(0)))
return bayesNet.evaluate(sample) / fg.probPrime(
continuous, sample.discrete())
return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0
def test_ratio(self):
"""
@ -153,23 +195,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
bayesNet = self.tiny(num_measurements=2)
# Sample from the Bayes net.
sample: gtsam.HybridValues = bayesNet.sample()
sample: HybridValues = bayesNet.sample()
# print(sample)
# Create a factor graph from the Bayes net with sampled measurements.
# The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)`
# and thus represents the same joint probability as the Bayes net.
fg = HybridGaussianFactorGraph()
for i in range(2):
conditional = bayesNet.atMixture(i)
measurement = gtsam.VectorValues()
measurement.insert(Z(i), sample.at(Z(i)))
factor = conditional.likelihood(measurement)
fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(2))
fg.push_back(bayesNet.atDiscrete(3))
# Estimate marginals using importance sampling.
marginals = self.estimate_marginals(bayesNet, sample)
# print(f"True mode: {sample.atDiscrete(M(0))}")
# print(f"P(mode=0; z0, z1) = {marginals[0]}")
# print(f"P(mode=1; z0, z1) = {marginals[1]}")
# print(fg)
# Check marginals based on sampled mode.
if sample.atDiscrete(M(0)) == 0:
self.assertGreater(marginals[0], marginals[1])
else:
self.assertGreater(marginals[1], marginals[0])
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
self.assertEqual(fg.size(), 4)
# Calculate ratio between Bayes net probability and the factor graph:
@ -185,10 +226,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
for i in range(10):
other = bayesNet.sample()
other.update(measurements)
# print(other)
# ratio = self.calculate_ratio(bayesNet, fg, other)
ratio = self.calculate_ratio(bayesNet, fg, other)
# print(f"Ratio: {ratio}\n")
# self.assertAlmostEqual(ratio, expected_ratio)
if (ratio > 0):
self.assertAlmostEqual(ratio, expected_ratio)
if __name__ == "__main__":