Merge pull request #1358 from borglab/hybrid/gaussian-mixture-factor

release/4.3a0
Frank Dellaert 2022-12-30 23:28:36 -05:00 committed by GitHub
commit d0821a57de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 364 additions and 234 deletions

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
@ -149,17 +150,19 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys(); const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents(); const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods( const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->likelihood(frontals); return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(frontals),
conditional->logNormalizationConstant()};
}); });
return boost::make_shared<GaussianMixtureFactor>( return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods); continuousParentKeys, discreteParentKeys, likelihoods);
} }
/* ************************************************************************* */ /* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) { std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
std::set<DiscreteKey> s; std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end()); s.insert(discreteKeys.begin(), discreteKeys.end());
return s; return s;
} }
@ -184,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
const GaussianConditional::shared_ptr &conditional) const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr { -> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); const DiscreteValues values(choices);
// Case where the gaussian mixture has the same // Case where the gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
@ -254,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
} }
/* *******************************************************************************/ /* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousValues, double GaussianMixture::error(const HybridValues &values) const {
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree. // Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(discreteValues); auto conditional = conditionals_(values.discrete());
return conditional->error(continuousValues); return conditional->error(values.continuous());
} }
} // namespace gtsam } // namespace gtsam

View File

@ -30,6 +30,7 @@
namespace gtsam { namespace gtsam {
class GaussianMixtureFactor; class GaussianMixtureFactor;
class HybridValues;
/** /**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as * @brief A conditional of gaussian mixtures indexed by discrete variables, as
@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Constructors /// @name Constructors
/// @{ /// @{
/// Defaut constructor, mainly for serialization. /// Default constructor, mainly for serialization.
GaussianMixture() = default; GaussianMixture() = default;
/** /**
@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Standard API /// @name Standard API
/// @{ /// @{
/// @brief Return the conditional Gaussian for the given discrete assignment.
GaussianConditional::shared_ptr operator()( GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const; const DiscreteValues &discreteValues) const;
@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture
* @brief Compute the error of this Gaussian Mixture given the continuous * @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment. * values and a discrete assignment.
* *
* @param continuousValues Continuous values at which to compute the error. * @param values Continuous values and discrete assignment.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double * @return double
*/ */
double error(const VectorValues &continuousValues, double error(const HybridValues &values) const override;
const DiscreteValues &discreteValues) const;
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture
}; };
/// Return the DiscreteKey vector as a set. /// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys); std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
// traits // traits
template <> template <>

View File

@ -22,6 +22,8 @@
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam { namespace gtsam {
@ -29,8 +31,11 @@ namespace gtsam {
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const Factors &factors) const Mixture &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {} : Base(continuousKeys, discreteKeys),
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
return FactorAndConstant{gf, 0.0};
}) {}
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
@ -43,11 +48,11 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
// Check the base and the factors: // Check the base and the factors:
return Base::equals(*e, tol) && return Base::equals(*e, tol) &&
factors_.equals(e->factors_, factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
[tol](const GaussianFactor::shared_ptr &f1, const FactorAndConstant &f2) {
const GaussianFactor::shared_ptr &f2) { return f1.factor->equals(*(f2.factor), tol) &&
return f1->equals(*f2, tol); std::abs(f1.constant - f2.constant) < tol;
}); });
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -60,7 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s,
} else { } else {
factors_.print( factors_.print(
"", [&](Key k) { return formatter(k); }, "", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string { [&](const FactorAndConstant &gf_z) -> std::string {
auto gf = gf_z.factor;
RedirectCout rd; RedirectCout rd;
std::cout << ":\n"; std::cout << ":\n";
if (gf && !gf->empty()) { if (gf && !gf->empty()) {
@ -75,8 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s,
} }
/* *******************************************************************************/ /* *******************************************************************************/
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
return factors_; return Mixture(factors_, [](const FactorAndConstant &factor_z) {
return factor_z.factor;
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -95,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
const { const {
auto wrap = [](const GaussianFactor::shared_ptr &factor) { auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor_z.factor);
return result; return result;
}; };
return {factors_, wrap}; return {factors_, wrap};
@ -107,21 +115,18 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error( AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
[continuousValues](const GaussianFactor::shared_ptr &factor) { return factor_z.error(continuousValues);
return factor->error(continuousValues); };
};
DecisionTree<Key, double> errorTree(factors_, errorFunc); DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree; return errorTree;
} }
/* *******************************************************************************/ /* *******************************************************************************/
double GaussianMixtureFactor::error( double GaussianMixtureFactor::error(const HybridValues &values) const {
const VectorValues &continuousValues, const FactorAndConstant factor_z = factors_(values.discrete());
const DiscreteValues &discreteValues) const { return factor_z.error(values.continuous());
// Directly index to get the conditional, no need to build the whole tree.
auto factor = factors_(discreteValues);
return factor->error(continuousValues);
} }
/* *******************************************************************************/
} // namespace gtsam } // namespace gtsam

View File

@ -23,17 +23,15 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>
namespace gtsam { namespace gtsam {
class GaussianFactorGraph; class GaussianFactorGraph;
class HybridValues;
// Needed for wrapper. class DiscreteValues;
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>; class VectorValues;
/** /**
* @brief Implementation of a discrete conditional mixture factor. * @brief Implementation of a discrete conditional mixture factor.
@ -53,9 +51,29 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using shared_ptr = boost::shared_ptr<This>; using shared_ptr = boost::shared_ptr<This>;
using Sum = DecisionTree<Key, GaussianFactorGraph>; using Sum = DecisionTree<Key, GaussianFactorGraph>;
using sharedFactor = boost::shared_ptr<GaussianFactor>;
/// typedef for Decision Tree of Gaussian Factors /// Gaussian factor and log of normalizing constant.
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>; struct FactorAndConstant {
sharedFactor factor;
double constant;
// Return error with constant correction.
double error(const VectorValues &values) const {
// Note minus sign: constant is log of normalization constant for probabilities.
// Error is the negative log-likelihood, hence we subtract the constant here.
return factor->error(values) - constant;
}
// Check pointer equality of the factor and exact equality of the constant.
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}
};
/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>;
private: private:
/// Decision tree of Gaussian factors indexed by discrete keys. /// Decision tree of Gaussian factors indexed by discrete keys.
@ -82,12 +100,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @param continuousKeys A vector of keys representing continuous variables. * @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and * @param discreteKeys A vector of keys representing discrete variables and
* their cardinalities. * their cardinalities.
* @param factors The decision tree of Gaussian Factors stored as the mixture * @param factors The decision tree of Gaussian factors stored as the mixture
* density. * density.
*/ */
GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const Factors &factors); const Mixture &factors);
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors_and_z)
: Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}
/** /**
* @brief Construct a new GaussianMixtureFactor object using a vector of * @brief Construct a new GaussianMixtureFactor object using a vector of
@ -99,9 +122,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/ */
GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors) const std::vector<sharedFactor> &factors)
: GaussianMixtureFactor(continuousKeys, discreteKeys, : GaussianMixtureFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {} Mixture(discreteKeys, factors)) {}
/// @} /// @}
/// @name Testable /// @name Testable
@ -113,9 +136,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
const std::string &s = "GaussianMixtureFactor\n", const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard API
/// @{
/// Getter for the underlying Gaussian Factor Decision Tree. /// Getter for the underlying Gaussian Factor Decision Tree.
const Factors &factors(); const Mixture factors() const;
/** /**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@ -137,21 +162,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/** /**
* @brief Compute the error of this Gaussian Mixture given the continuous * @brief Compute the log-likelihood, including the log-normalizing constant.
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double * @return double
*/ */
double error(const VectorValues &continuousValues, double error(const HybridValues &values) const override;
const DiscreteValues &discreteValues) const;
/// Add MixtureFactor to a Sum, syntactic sugar. /// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum); sum = factor.add(sum);
return sum; return sum;
} }
/// @}
}; };
// traits // traits

View File

@ -1,5 +1,5 @@
/* ---------------------------------------------------------------------------- /* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation, * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415 * Atlanta, Georgia 30332-0415
* All Rights Reserved * All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
@ -12,6 +12,7 @@
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
* @author Shangjie Xue * @author Shangjie Xue
* @author Frank Dellaert
* @date January 2022 * @date January 2022
*/ */
@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues, double HybridBayesNet::error(const HybridValues &values) const {
const DiscreteValues &discreteValues) const { GaussianBayesNet gbn = choose(values.discrete());
GaussianBayesNet gbn = choose(discreteValues); return gbn.error(values.continuous());
return gbn.error(continuousValues);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief 0.5 * sum of squared Mahalanobis distances * @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment. * for a specific discrete assignment.
* *
* @param continuousValues Continuous values at which to compute the error. * @param values Continuous values and discrete assignment.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @return double * @return double
*/ */
double error(const VectorValues &continuousValues, double error(const HybridValues &values) const;
const DiscreteValues &discreteValues) const;
/** /**
* @brief Compute conditional error for each discrete assignment, * @brief Compute conditional error for each discrete assignment,

View File

@ -52,7 +52,7 @@ namespace gtsam {
* having diamond inheritances, and neutralized the need to change other * having diamond inheritances, and neutralized the need to change other
* components of GTSAM to make hybrid elimination work. * components of GTSAM to make hybrid elimination work.
* *
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon * A great reference to the type-erasure pattern is Eduardo Madrid's CppCon
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs). * talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
* *
* @ingroup hybrid * @ingroup hybrid
@ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional
*/ */
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture); HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
/**
* @brief Return HybridConditional as a GaussianMixture
* @return nullptr if not a mixture
* @return GaussianMixture::shared_ptr otherwise
*/
GaussianMixture::shared_ptr asMixture() {
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
}
/**
* @brief Return HybridConditional as a GaussianConditional
* @return nullptr if not a GaussianConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() {
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
}
/**
* @brief Return conditional as a DiscreteConditional
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscrete() {
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
}
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional
bool equals(const HybridFactor& other, double tol = 1e-9) const override; bool equals(const HybridFactor& other, double tol = 1e-9) const override;
/// @} /// @}
/// @name Standard Interface
/// @{
/**
* @brief Return HybridConditional as a GaussianMixture
* @return nullptr if not a mixture
* @return GaussianMixture::shared_ptr otherwise
*/
GaussianMixture::shared_ptr asMixture() const {
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
}
/**
* @brief Return HybridConditional as a GaussianConditional
* @return nullptr if not a GaussianConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() const {
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
}
/**
* @brief Return conditional as a DiscreteConditional
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscrete() const {
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
}
/// Get the type-erased pointer to the inner type /// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; } boost::shared_ptr<Factor> inner() { return inner_; }
/// Return the error of the underlying conditional.
/// Currently only implemented for Gaussian mixture.
double error(const HybridValues& values) const override {
if (auto gm = asMixture()) {
return gm->error(values);
} else {
throw std::runtime_error(
"HybridConditional::error: only implemented for Gaussian mixture");
}
}
/// @}
private: private:
/** Serialization function */ /** Serialization function */
friend class boost::serialization::access; friend class boost::serialization::access;

View File

@ -17,6 +17,7 @@
*/ */
#include <gtsam/hybrid/HybridDiscreteFactor.h> #include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
@ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s,
inner_->print("\n", formatter); inner_->print("\n", formatter);
}; };
/* ************************************************************************ */
double HybridDiscreteFactor::error(const HybridValues &values) const {
return -log((*inner_)(values.discrete()));
}
/* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -24,10 +24,12 @@
namespace gtsam { namespace gtsam {
class HybridValues;
/** /**
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows * A HybridDiscreteFactor is a thin container for DiscreteFactor, which
* us to hide the implementation of DiscreteFactor and thus avoid diamond * allows us to hide the implementation of DiscreteFactor and thus avoid
* inheritance. * diamond inheritance.
* *
* @ingroup hybrid * @ingroup hybrid
*/ */
@ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard Interface
/// @{
/// Return pointer to the internal discrete factor /// Return pointer to the internal discrete factor
DiscreteFactor::shared_ptr inner() const { return inner_; } DiscreteFactor::shared_ptr inner() const { return inner_; }
/// Return the error of the underlying Discrete Factor.
double error(const HybridValues &values) const override;
/// @}
}; };
// traits // traits

View File

@ -26,6 +26,8 @@
#include <string> #include <string>
namespace gtsam { namespace gtsam {
class HybridValues;
KeyVector CollectKeys(const KeyVector &continuousKeys, KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys); const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
@ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/**
* @brief Compute the error of this hybrid factor given the continuous
* values and a discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
virtual double error(const HybridValues &values) const = 0;
/// True if this is a factor of discrete variables only. /// True if this is a factor of discrete variables only.
bool isDiscrete() const { return isDiscrete_; } bool isDiscrete() const { return isDiscrete_; }

View File

@ -16,6 +16,7 @@
*/ */
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/HessianFactor.h> #include <gtsam/linear/HessianFactor.h>
#include <gtsam/linear/JacobianFactor.h> #include <gtsam/linear/JacobianFactor.h>
@ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s,
inner_->print("\n", formatter); inner_->print("\n", formatter);
}; };
/* ************************************************************************ */
double HybridGaussianFactor::error(const HybridValues &values) const {
return inner_->error(values.continuous());
}
/* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -25,6 +25,7 @@ namespace gtsam {
// Forward declarations // Forward declarations
class JacobianFactor; class JacobianFactor;
class HessianFactor; class HessianFactor;
class HybridValues;
/** /**
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
@ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard Interface
/// @{
/// Return pointer to the internal Gaussian factor
GaussianFactor::shared_ptr inner() const { return inner_; } GaussianFactor::shared_ptr inner() const { return inner_; }
/// Return the error of the underlying Gaussian factor.
double error(const HybridValues &values) const override;
/// @}
}; };
// traits // traits

View File

@ -55,13 +55,14 @@
namespace gtsam { namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>; template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian( static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
using Y = GaussianFactorGraph; using Y = GaussianFactorGraph;
// If the decision tree is not intiialized, then intialize it. // If the decision tree is not initialized, then initialize it.
if (sum.empty()) { if (sum.empty()) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals(
for (auto &f : factors) { for (auto &f : factors) {
if (f->isHybrid()) { if (f->isHybrid()) {
if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { // TODO(dellaert): just use a virtual method defined in HybridFactor.
sum = cgmf->add(sum); if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = gm->add(sum);
} }
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) { if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum); sum = gm->asMixture()->add(sum);
@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
const KeySet &continuousSeparator, const KeySet &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) { const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree, // NOTE: since we use the special JunctionTree,
// only possiblity is continuous conditioned on discrete. // only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end()); discreteSeparatorSet.end());
@ -204,16 +206,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
}; };
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
using EliminationPair = GaussianFactorGraph::EliminationResult; using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>;
KeyVector keysOfEliminated; // Not the ordering KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianFactorGraph &graph) auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
-> GaussianFactorGraph::EliminationResult {
if (graph.empty()) { if (graph.empty()) {
return {nullptr, nullptr}; return {nullptr, {nullptr, 0.0}};
} }
#ifdef HYBRID_TIMING #ifdef HYBRID_TIMING
@ -222,18 +224,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
std::pair<boost::shared_ptr<GaussianConditional>, std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>> boost::shared_ptr<GaussianFactor>>
result = EliminatePreferCholesky(graph, frontalKeys); conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
// Initialize the keysOfEliminated to be the keys of the // Initialize the keysOfEliminated to be the keys of the
// eliminated GaussianConditional // eliminated GaussianConditional
keysOfEliminated = result.first->keys(); keysOfEliminated = conditional_factor.first->keys();
keysOfSeparator = result.second->keys(); keysOfSeparator = conditional_factor.second->keys();
#ifdef HYBRID_TIMING #ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate); gttoc_(hybrid_eliminate);
#endif #endif
return result; return {conditional_factor.first, {conditional_factor.second, 0.0}};
}; };
// Perform elimination! // Perform elimination!
@ -246,8 +248,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Separate out decision tree into conditionals and remaining factors. // Separate out decision tree into conditionals and remaining factors.
auto pair = unzip(eliminationResults); auto pair = unzip(eliminationResults);
const auto &separatorFactors = pair.second;
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
// Create the GaussianMixture from the conditionals // Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>( auto conditional = boost::make_shared<GaussianMixture>(
@ -257,16 +258,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// DiscreteFactor, with the error for each discrete choice. // DiscreteFactor, with the error for each discrete choice.
if (keysOfSeparator.empty()) { if (keysOfSeparator.empty()) {
VectorValues empty_values; VectorValues empty_values;
auto factorProb = [&](const GaussianFactor::shared_ptr &factor) { auto factorProb =
if (!factor) { [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
return 0.0; // If nullptr, return 0.0 probability GaussianFactor::shared_ptr factor = factor_z.factor;
} else { if (!factor) {
// This is the probability q(μ) at the MLE point. return 0.0; // If nullptr, return 0.0 probability
double error = } else {
0.5 * std::abs(factor->augmentedInformation().determinant()); // This is the probability q(μ) at the MLE point.
return std::exp(-error); double error =
} 0.5 * std::abs(factor->augmentedInformation().determinant()) +
}; factor_z.constant;
return std::exp(-error);
}
};
DecisionTree<Key, double> fdt(separatorFactors, factorProb); DecisionTree<Key, double> fdt(separatorFactors, factorProb);
auto discreteFactor = auto discreteFactor =
@ -452,6 +456,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// Iterate over each factor. // Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) { for (size_t idx = 0; idx < size(); idx++) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error; AlgebraicDecisionTree<Key> factor_error;
if (factors_.at(idx)->isHybrid()) { if (factors_.at(idx)->isHybrid()) {
@ -491,38 +496,17 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
} }
/* ************************************************************************ */ /* ************************************************************************ */
double HybridGaussianFactorGraph::error( double HybridGaussianFactorGraph::error(const HybridValues &values) const {
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = 0.0; double error = 0.0;
for (size_t idx = 0; idx < size(); idx++) { for (auto &factor : factors_) {
auto factor = factors_.at(idx); error += factor->error(values);
if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(continuousValues, discreteValues);
}
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(continuousValues, discreteValues);
}
} else if (factor->isContinuous()) {
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(continuousValues);
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(continuousValues);
}
}
} }
return error; return error;
} }
/* ************************************************************************ */ /* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime( double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
const VectorValues &continuousValues, double error = this->error(values);
const DiscreteValues &discreteValues) const {
double error = this->error(continuousValues, discreteValues);
// NOTE: The 0.5 term is handled by each factor // NOTE: The 0.5 term is handled by each factor
return std::exp(-error); return std::exp(-error);
} }

View File

@ -12,7 +12,7 @@
/** /**
* @file HybridGaussianFactorGraph.h * @file HybridGaussianFactorGraph.h
* @brief Linearized Hybrid factor graph that uses type erasure * @brief Linearized Hybrid factor graph that uses type erasure
* @author Fan Jiang, Varun Agrawal * @author Fan Jiang, Varun Agrawal, Frank Dellaert
* @date Mar 11, 2022 * @date Mar 11, 2022
*/ */
@ -38,6 +38,7 @@ class HybridBayesTree;
class HybridJunctionTree; class HybridJunctionTree;
class DecisionTreeFactor; class DecisionTreeFactor;
class JacobianFactor; class JacobianFactor;
class HybridValues;
/** /**
* @brief Main elimination function for HybridGaussianFactorGraph. * @brief Main elimination function for HybridGaussianFactorGraph.
@ -186,14 +187,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Compute error given a continuous vector values * @brief Compute error given a continuous vector values
* and a discrete assignment. * and a discrete assignment.
* *
* @param continuousValues The continuous VectorValues
* for computing the error.
* @param discreteValues The specific discrete assignment
* whose error we wish to compute.
* @return double * @return double
*/ */
double error(const VectorValues& continuousValues, double error(const HybridValues& values) const;
const DiscreteValues& discreteValues) const;
/** /**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
@ -210,13 +206,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Compute the unnormalized posterior probability for a continuous * @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment. * vector values given a specific assignment.
* *
* @param continuousValues The vector values for which to compute the
* posterior probability.
* @param discreteValues The specific assignment to use for the computation.
* @return double * @return double
*/ */
double probPrime(const VectorValues& continuousValues, double probPrime(const HybridValues& values) const;
const DiscreteValues& discreteValues) const;
/** /**
* @brief Return a Colamd constrained ordering where the discrete keys are * @brief Return a Colamd constrained ordering where the discrete keys are

View File

@ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor {
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard Interface
/// @{
NonlinearFactor::shared_ptr inner() const { return inner_; } NonlinearFactor::shared_ptr inner() const { return inner_; }
/// Error for HybridValues is not provided for nonlinear factor.
double error(const HybridValues &values) const override {
throw std::runtime_error(
"HybridNonlinearFactor::error(HybridValues) not implemented.");
}
/// Linearize to a HybridGaussianFactor at the linearization point `c`. /// Linearize to a HybridGaussianFactor at the linearization point `c`.
boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const { boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const {
return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c)); return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c));
} }
/// @}
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor {
factor, continuousValues); factor, continuousValues);
} }
/// Error for HybridValues is not provided for nonlinear hybrid factor.
double error(const HybridValues &values) const override {
throw std::runtime_error(
"MixtureFactor::error(HybridValues) not implemented.");
}
size_t dim() const { size_t dim() const {
// TODO(Varun) // TODO(Varun)
throw std::runtime_error("MixtureFactor::dim not implemented."); throw std::runtime_error("MixtureFactor::dim not implemented.");

View File

@ -183,10 +183,8 @@ class HybridGaussianFactorGraph {
bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const;
// evaluation // evaluation
double error(const gtsam::VectorValues& continuousValues, double error(const gtsam::HybridValues& values) const;
const gtsam::DiscreteValues& discreteValues) const; double probPrime(const gtsam::HybridValues& values) const;
double probPrime(const gtsam::VectorValues& continuousValues,
const gtsam::DiscreteValues& discreteValues) const;
gtsam::HybridBayesNet* eliminateSequential(); gtsam::HybridBayesNet* eliminateSequential();
gtsam::HybridBayesNet* eliminateSequential( gtsam::HybridBayesNet* eliminateSequential(

View File

@ -128,9 +128,9 @@ TEST(GaussianMixture, Error) {
// Regression for non-tree version. // Regression for non-tree version.
DiscreteValues assignment; DiscreteValues assignment;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); EXPECT_DOUBLES_EQUAL(0.5, mixture.error({values, assignment}), 1e-8);
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error({values, assignment}),
1e-8); 1e-8);
} }
@ -179,7 +179,9 @@ TEST(GaussianMixture, Likelihood) {
const GaussianMixtureFactor::Factors factors( const GaussianMixtureFactor::Factors factors(
gm.conditionals(), gm.conditionals(),
[measurements](const GaussianConditional::shared_ptr& conditional) { [measurements](const GaussianConditional::shared_ptr& conditional) {
return conditional->likelihood(measurements); return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(measurements),
conditional->logNormalizationConstant()};
}); });
const GaussianMixtureFactor expected({X(0)}, {mode}, factors); const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
EXPECT(assert_equal(*factor, expected)); EXPECT(assert_equal(*factor, expected));

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Symbol.h> #include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
@ -188,7 +189,8 @@ TEST(GaussianMixtureFactor, Error) {
DiscreteValues discreteValues; DiscreteValues discreteValues;
discreteValues[m1.first] = 1; discreteValues[m1.first] = 1;
EXPECT_DOUBLES_EQUAL( EXPECT_DOUBLES_EQUAL(
4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9); 4.0, mixtureFactor.error({continuousValues, discreteValues}),
1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -188,14 +188,14 @@ TEST(HybridBayesNet, Optimize) {
HybridValues delta = hybridBayesNet->optimize(); HybridValues delta = hybridBayesNet->optimize();
//TODO(Varun) The expectedAssignment should be 111, not 101 // TODO(Varun) The expectedAssignment should be 111, not 101
DiscreteValues expectedAssignment; DiscreteValues expectedAssignment;
expectedAssignment[M(0)] = 1; expectedAssignment[M(0)] = 1;
expectedAssignment[M(1)] = 0; expectedAssignment[M(1)] = 0;
expectedAssignment[M(2)] = 1; expectedAssignment[M(2)] = 1;
EXPECT(assert_equal(expectedAssignment, delta.discrete())); EXPECT(assert_equal(expectedAssignment, delta.discrete()));
//TODO(Varun) This should be all -Vector1::Ones() // TODO(Varun) This should be all -Vector1::Ones()
VectorValues expectedValues; VectorValues expectedValues;
expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
@ -243,8 +243,8 @@ TEST(HybridBayesNet, Error) {
double total_error = 0; double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
if (hybridBayesNet->at(idx)->isHybrid()) { if (hybridBayesNet->at(idx)->isHybrid()) {
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(), double error = hybridBayesNet->atMixture(idx)->error(
discrete_values); {delta.continuous(), discrete_values});
total_error += error; total_error += error;
} else if (hybridBayesNet->at(idx)->isContinuous()) { } else if (hybridBayesNet->at(idx)->isContinuous()) {
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous()); double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
@ -253,7 +253,7 @@ TEST(HybridBayesNet, Error) {
} }
EXPECT_DOUBLES_EQUAL( EXPECT_DOUBLES_EQUAL(
total_error, hybridBayesNet->error(delta.continuous(), discrete_values), total_error, hybridBayesNet->error({delta.continuous(), discrete_values}),
1e-9); 1e-9);
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);

View File

@ -273,7 +273,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
continue; continue;
} }
double error = graph.error(delta, assignment); double error = graph.error({delta, assignment});
probPrimes.push_back(exp(-error)); probPrimes.push_back(exp(-error));
} }
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes); AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
@ -487,8 +487,8 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
const HybridValues& sample) -> double { const HybridValues& sample) -> double {
const DiscreteValues assignment = sample.discrete(); const DiscreteValues assignment = sample.discrete();
// Compute in log form for numerical stability // Compute in log form for numerical stability
double log_ratio = bayesNet->error(sample.continuous(), assignment) - double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
factorGraph->error(sample.continuous(), assignment); factorGraph->error({sample.continuous(), assignment});
double ratio = exp(-log_ratio); double ratio = exp(-log_ratio);
return ratio; return ratio;
}; };

View File

@ -575,18 +575,14 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
HybridBayesNet::shared_ptr hybridBayesNet = HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering); graph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize(); const HybridValues delta = hybridBayesNet->optimize();
double error = graph.error(delta.continuous(), delta.discrete()); const double error = graph.error(delta);
double expected_error = 0.490243199;
// regression
EXPECT(assert_equal(expected_error, error, 1e-9));
double probs = exp(-error);
double expected_probs = graph.probPrime(delta.continuous(), delta.discrete());
// regression // regression
EXPECT(assert_equal(expected_probs, probs, 1e-7)); EXPECT(assert_equal(1.58886, error, 1e-5));
// Real test:
EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7));
} }
/* ****************************************************************************/ /* ****************************************************************************/

View File

@ -168,26 +168,30 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
double GaussianConditional::logDeterminant() const { double GaussianConditional::logDeterminant() const {
double logDet; if (get_model()) {
if (this->get_model()) { Vector diag = R().diagonal();
Vector diag = this->R().diagonal(); get_model()->whitenInPlace(diag);
this->get_model()->whitenInPlace(diag); return diag.unaryExpr([](double x) { return log(x); }).sum();
logDet = diag.unaryExpr([](double x) { return log(x); }).sum();
} else { } else {
logDet = return R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
this->R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
} }
return logDet;
} }
/* ************************************************************************* */ /* ************************************************************************* */
// density = exp(-error(x)) / sqrt((2*pi)^n*det(Sigma)) // normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
// log = -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) // log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
double GaussianConditional::logDensity(const VectorValues& x) const { double GaussianConditional::logNormalizationConstant() const {
constexpr double log2pi = 1.8378770664093454835606594728112; constexpr double log2pi = 1.8378770664093454835606594728112;
size_t n = d().size(); size_t n = d().size();
// log det(Sigma)) = - 2.0 * logDeterminant() // log det(Sigma)) = - 2.0 * logDeterminant()
return - error(x) - 0.5 * n * log2pi + logDeterminant(); return - 0.5 * n * log2pi + logDeterminant();
}
/* ************************************************************************* */
// density = k exp(-error(x))
// log = log(k) -error(x) - 0.5 * n*log(2*pi)
double GaussianConditional::logDensity(const VectorValues& x) const {
return logNormalizationConstant() - error(x);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -169,7 +169,7 @@ namespace gtsam {
* *
* @return double * @return double
*/ */
double determinant() const { return exp(this->logDeterminant()); } inline double determinant() const { return exp(logDeterminant()); }
/** /**
* @brief Compute the log determinant of the R matrix. * @brief Compute the log determinant of the R matrix.
@ -184,6 +184,19 @@ namespace gtsam {
*/ */
double logDeterminant() const; double logDeterminant() const;
/**
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
* log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
*/
double logNormalizationConstant() const;
/**
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
*/
inline double normalizationConstant() const {
return exp(logNormalizationConstant());
}
/** /**
* Solves a conditional Gaussian and writes the solution into the entries of * Solves a conditional Gaussian and writes the solution into the entries of
* \c x for each frontal variable of the conditional. The parents are * \c x for each frontal variable of the conditional. The parents are

View File

@ -6,7 +6,7 @@ All Rights Reserved
See LICENSE for the license information See LICENSE for the license information
Unit tests for Hybrid Factor Graphs. Unit tests for Hybrid Factor Graphs.
Author: Fan Jiang Author: Fan Jiang, Varun Agrawal, Frank Dellaert
""" """
# pylint: disable=invalid-name, no-name-in-module, no-member # pylint: disable=invalid-name, no-name-in-module, no-member
@ -18,13 +18,14 @@ from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
GaussianMixture, GaussianMixtureFactor, GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues,
HybridGaussianFactorGraph, JacobianFactor, Ordering, HybridGaussianFactorGraph, JacobianFactor, Ordering,
noiseModel) noiseModel)
class TestHybridGaussianFactorGraph(GtsamTestCase): class TestHybridGaussianFactorGraph(GtsamTestCase):
"""Unit tests for HybridGaussianFactorGraph.""" """Unit tests for HybridGaussianFactorGraph."""
def test_create(self): def test_create(self):
"""Test construction of hybrid factor graph.""" """Test construction of hybrid factor graph."""
model = noiseModel.Unit.Create(3) model = noiseModel.Unit.Create(3)
@ -81,13 +82,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertEqual(hv.atDiscrete(C(0)), 1) self.assertEqual(hv.atDiscrete(C(0)), 1)
@staticmethod @staticmethod
def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet: def tiny(num_measurements: int = 1) -> HybridBayesNet:
""" """
Create a tiny two variable hybrid model which represents Create a tiny two variable hybrid model which represents
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
""" """
# Create hybrid Bayes net. # Create hybrid Bayes net.
bayesNet = gtsam.HybridBayesNet() bayesNet = HybridBayesNet()
# Create mode key: 0 is low-noise, 1 is high-noise. # Create mode key: 0 is low-noise, 1 is high-noise.
mode = (M(0), 2) mode = (M(0), 2)
@ -113,35 +114,76 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
bayesNet.addGaussian(prior_on_x0) bayesNet.addGaussian(prior_on_x0)
# Add prior on mode. # Add prior on mode.
bayesNet.emplaceDiscrete(mode, "1/1") bayesNet.emplaceDiscrete(mode, "4/6")
return bayesNet return bayesNet
@staticmethod
def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues):
"""Create a factor graph from the Bayes net with sampled measurements.
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
and thus represents the same joint probability as the Bayes net.
"""
fg = HybridGaussianFactorGraph()
num_measurements = bayesNet.size() - 2
for i in range(num_measurements):
conditional = bayesNet.atMixture(i)
measurement = gtsam.VectorValues()
measurement.insert(Z(i), sample.at(Z(i)))
factor = conditional.likelihood(measurement)
fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(num_measurements))
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
return fg
@classmethod
def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000):
"""Do importance sampling to get an estimate of the discrete marginal P(mode)."""
# Use prior on x0, mode as proposal density.
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode)
# Allocate space for marginals.
marginals = np.zeros((2,))
# Do importance sampling.
num_measurements = bayesNet.size() - 2
for s in range(N):
proposed = prior.sample()
for i in range(num_measurements):
z_i = sample.at(Z(i))
proposed.insert(Z(i), z_i)
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
marginals[proposed.atDiscrete(M(0))] += weight
# print marginals:
marginals /= marginals.sum()
return marginals
def test_tiny(self): def test_tiny(self):
"""Test a tiny two variable hybrid model.""" """Test a tiny two variable hybrid model."""
bayesNet = self.tiny() bayesNet = self.tiny()
sample = bayesNet.sample() sample = bayesNet.sample()
# print(sample) # print(sample)
# Create a factor graph from the Bayes net with sampled measurements. # Estimate marginals using importance sampling.
fg = HybridGaussianFactorGraph() marginals = self.estimate_marginals(bayesNet, sample)
conditional = bayesNet.atMixture(0) # print(f"True mode: {sample.atDiscrete(M(0))}")
measurement = gtsam.VectorValues() # print(f"P(mode=0; z0) = {marginals[0]}")
measurement.insert(Z(0), sample.at(Z(0))) # print(f"P(mode=1; z0) = {marginals[1]}")
factor = conditional.likelihood(measurement)
fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(1))
fg.push_back(bayesNet.atDiscrete(2))
# Check that the estimate is close to the true value.
self.assertAlmostEqual(marginals[0], 0.4, delta=0.1)
self.assertAlmostEqual(marginals[1], 0.6, delta=0.1)
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
self.assertEqual(fg.size(), 3) self.assertEqual(fg.size(), 3)
@staticmethod @staticmethod
def calculate_ratio(bayesNet, fg, sample): def calculate_ratio(bayesNet: HybridBayesNet,
fg: HybridGaussianFactorGraph,
sample: HybridValues):
"""Calculate ratio between Bayes net probability and the factor graph.""" """Calculate ratio between Bayes net probability and the factor graph."""
continuous = gtsam.VectorValues() return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0
continuous.insert(X(0), sample.at(X(0)))
return bayesNet.evaluate(sample) / fg.probPrime(
continuous, sample.discrete())
def test_ratio(self): def test_ratio(self):
""" """
@ -153,23 +195,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
bayesNet = self.tiny(num_measurements=2) bayesNet = self.tiny(num_measurements=2)
# Sample from the Bayes net. # Sample from the Bayes net.
sample: gtsam.HybridValues = bayesNet.sample() sample: HybridValues = bayesNet.sample()
# print(sample) # print(sample)
# Create a factor graph from the Bayes net with sampled measurements. # Estimate marginals using importance sampling.
# The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)` marginals = self.estimate_marginals(bayesNet, sample)
# and thus represents the same joint probability as the Bayes net. # print(f"True mode: {sample.atDiscrete(M(0))}")
fg = HybridGaussianFactorGraph() # print(f"P(mode=0; z0, z1) = {marginals[0]}")
for i in range(2): # print(f"P(mode=1; z0, z1) = {marginals[1]}")
conditional = bayesNet.atMixture(i)
measurement = gtsam.VectorValues()
measurement.insert(Z(i), sample.at(Z(i)))
factor = conditional.likelihood(measurement)
fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(2))
fg.push_back(bayesNet.atDiscrete(3))
# print(fg) # Check marginals based on sampled mode.
if sample.atDiscrete(M(0)) == 0:
self.assertGreater(marginals[0], marginals[1])
else:
self.assertGreater(marginals[1], marginals[0])
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
self.assertEqual(fg.size(), 4) self.assertEqual(fg.size(), 4)
# Calculate ratio between Bayes net probability and the factor graph: # Calculate ratio between Bayes net probability and the factor graph:
@ -185,10 +226,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
for i in range(10): for i in range(10):
other = bayesNet.sample() other = bayesNet.sample()
other.update(measurements) other.update(measurements)
# print(other) ratio = self.calculate_ratio(bayesNet, fg, other)
# ratio = self.calculate_ratio(bayesNet, fg, other)
# print(f"Ratio: {ratio}\n") # print(f"Ratio: {ratio}\n")
# self.assertAlmostEqual(ratio, expected_ratio) if (ratio > 0):
self.assertAlmostEqual(ratio, expected_ratio)
if __name__ == "__main__": if __name__ == "__main__":