Merge pull request #1358 from borglab/hybrid/gaussian-mixture-factor
commit
d0821a57de
|
|
@ -22,6 +22,7 @@
|
|||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
|
|
@ -149,17 +150,19 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
|||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||
const KeyVector continuousParentKeys = continuousParents();
|
||||
const GaussianMixtureFactor::Factors likelihoods(
|
||||
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
return conditional->likelihood(frontals);
|
||||
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
return GaussianMixtureFactor::FactorAndConstant{
|
||||
conditional->likelihood(frontals),
|
||||
conditional->logNormalizationConstant()};
|
||||
});
|
||||
return boost::make_shared<GaussianMixtureFactor>(
|
||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
||||
std::set<DiscreteKey> s;
|
||||
s.insert(dkeys.begin(), dkeys.end());
|
||||
s.insert(discreteKeys.begin(), discreteKeys.end());
|
||||
return s;
|
||||
}
|
||||
|
||||
|
|
@ -184,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
|||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues values(choices);
|
||||
const DiscreteValues values(choices);
|
||||
|
||||
// Case where the gaussian mixture has the same
|
||||
// discrete keys as the decision tree.
|
||||
|
|
@ -254,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixture::error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
double GaussianMixture::error(const HybridValues &values) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto conditional = conditionals_(discreteValues);
|
||||
return conditional->error(continuousValues);
|
||||
auto conditional = conditionals_(values.discrete());
|
||||
return conditional->error(values.continuous());
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@
|
|||
namespace gtsam {
|
||||
|
||||
class GaussianMixtureFactor;
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
|
||||
|
|
@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture
|
|||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Defaut constructor, mainly for serialization.
|
||||
/// Default constructor, mainly for serialization.
|
||||
GaussianMixture() = default;
|
||||
|
||||
/**
|
||||
|
|
@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture
|
|||
/// @name Standard API
|
||||
/// @{
|
||||
|
||||
/// @brief Return the conditional Gaussian for the given discrete assignment.
|
||||
GaussianConditional::shared_ptr operator()(
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
|
|
@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture
|
|||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||
* values and a discrete assignment.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @param discreteValues The discrete assignment for a specific mode sequence.
|
||||
* @param values Continuous values and discrete assignment.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const;
|
||||
double error(const HybridValues &values) const override;
|
||||
|
||||
/**
|
||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||
|
|
@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture
|
|||
};
|
||||
|
||||
/// Return the DiscreteKey vector as a set.
|
||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);
|
||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
|
||||
|
||||
// traits
|
||||
template <>
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@
|
|||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
|
@ -29,8 +31,11 @@ namespace gtsam {
|
|||
/* *******************************************************************************/
|
||||
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const Factors &factors)
|
||||
: Base(continuousKeys, discreteKeys), factors_(factors) {}
|
||||
const Mixture &factors)
|
||||
: Base(continuousKeys, discreteKeys),
|
||||
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
|
||||
return FactorAndConstant{gf, 0.0};
|
||||
}) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
||||
|
|
@ -43,10 +48,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
|||
|
||||
// Check the base and the factors:
|
||||
return Base::equals(*e, tol) &&
|
||||
factors_.equals(e->factors_,
|
||||
[tol](const GaussianFactor::shared_ptr &f1,
|
||||
const GaussianFactor::shared_ptr &f2) {
|
||||
return f1->equals(*f2, tol);
|
||||
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
|
||||
const FactorAndConstant &f2) {
|
||||
return f1.factor->equals(*(f2.factor), tol) &&
|
||||
std::abs(f1.constant - f2.constant) < tol;
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -60,7 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s,
|
|||
} else {
|
||||
factors_.print(
|
||||
"", [&](Key k) { return formatter(k); },
|
||||
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
|
||||
[&](const FactorAndConstant &gf_z) -> std::string {
|
||||
auto gf = gf_z.factor;
|
||||
RedirectCout rd;
|
||||
std::cout << ":\n";
|
||||
if (gf && !gf->empty()) {
|
||||
|
|
@ -75,8 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s,
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
|
||||
return factors_;
|
||||
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
|
||||
return Mixture(factors_, [](const FactorAndConstant &factor_z) {
|
||||
return factor_z.factor;
|
||||
});
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
@ -95,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
|||
/* *******************************************************************************/
|
||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||
const {
|
||||
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
|
||||
auto wrap = [](const FactorAndConstant &factor_z) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor);
|
||||
result.push_back(factor_z.factor);
|
||||
return result;
|
||||
};
|
||||
return {factors_, wrap};
|
||||
|
|
@ -107,21 +115,18 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
|||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc =
|
||||
[continuousValues](const GaussianFactor::shared_ptr &factor) {
|
||||
return factor->error(continuousValues);
|
||||
auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
|
||||
return factor_z.error(continuousValues);
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||
return errorTree;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixtureFactor::error(
|
||||
const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto factor = factors_(discreteValues);
|
||||
return factor->error(continuousValues);
|
||||
double GaussianMixtureFactor::error(const HybridValues &values) const {
|
||||
const FactorAndConstant factor_z = factors_(values.discrete());
|
||||
return factor_z.error(values.continuous());
|
||||
}
|
||||
/* *******************************************************************************/
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -23,17 +23,15 @@
|
|||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class GaussianFactorGraph;
|
||||
|
||||
// Needed for wrapper.
|
||||
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
|
||||
class HybridValues;
|
||||
class DiscreteValues;
|
||||
class VectorValues;
|
||||
|
||||
/**
|
||||
* @brief Implementation of a discrete conditional mixture factor.
|
||||
|
|
@ -53,9 +51,29 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||
using sharedFactor = boost::shared_ptr<GaussianFactor>;
|
||||
|
||||
/// typedef for Decision Tree of Gaussian Factors
|
||||
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
|
||||
/// Gaussian factor and log of normalizing constant.
|
||||
struct FactorAndConstant {
|
||||
sharedFactor factor;
|
||||
double constant;
|
||||
|
||||
// Return error with constant correction.
|
||||
double error(const VectorValues &values) const {
|
||||
// Note minus sign: constant is log of normalization constant for probabilities.
|
||||
// Errors is the negative log-likelihood, hence we subtract the constant here.
|
||||
return factor->error(values) - constant;
|
||||
}
|
||||
|
||||
// Check pointer equality.
|
||||
bool operator==(const FactorAndConstant &other) const {
|
||||
return factor == other.factor && constant == other.constant;
|
||||
}
|
||||
};
|
||||
|
||||
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||
using Factors = DecisionTree<Key, FactorAndConstant>;
|
||||
using Mixture = DecisionTree<Key, sharedFactor>;
|
||||
|
||||
private:
|
||||
/// Decision tree of Gaussian factors indexed by discrete keys.
|
||||
|
|
@ -82,12 +100,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
* @param continuousKeys A vector of keys representing continuous variables.
|
||||
* @param discreteKeys A vector of keys representing discrete variables and
|
||||
* their cardinalities.
|
||||
* @param factors The decision tree of Gaussian Factors stored as the mixture
|
||||
* @param factors The decision tree of Gaussian factors stored as the mixture
|
||||
* density.
|
||||
*/
|
||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const Factors &factors);
|
||||
const Mixture &factors);
|
||||
|
||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const Factors &factors_and_z)
|
||||
: Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}
|
||||
|
||||
/**
|
||||
* @brief Construct a new GaussianMixtureFactor object using a vector of
|
||||
|
|
@ -99,9 +122,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
*/
|
||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const std::vector<GaussianFactor::shared_ptr> &factors)
|
||||
const std::vector<sharedFactor> &factors)
|
||||
: GaussianMixtureFactor(continuousKeys, discreteKeys,
|
||||
Factors(discreteKeys, factors)) {}
|
||||
Mixture(discreteKeys, factors)) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
|
|
@ -113,9 +136,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
const std::string &s = "GaussianMixtureFactor\n",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
/// @}
|
||||
/// @name Standard API
|
||||
/// @{
|
||||
|
||||
/// Getter for the underlying Gaussian Factor Decision Tree.
|
||||
const Factors &factors();
|
||||
const Mixture factors() const;
|
||||
|
||||
/**
|
||||
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
||||
|
|
@ -137,21 +162,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||
* values and a discrete assignment.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @param discreteValues The discrete assignment for a specific mode sequence.
|
||||
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const;
|
||||
double error(const HybridValues &values) const override;
|
||||
|
||||
/// Add MixtureFactor to a Sum, syntactic sugar.
|
||||
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
||||
sum = factor.add(sum);
|
||||
return sum;
|
||||
}
|
||||
/// @}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
|
@ -12,6 +12,7 @@
|
|||
* @author Fan Jiang
|
||||
* @author Varun Agrawal
|
||||
* @author Shangjie Xue
|
||||
* @author Frank Dellaert
|
||||
* @date January 2022
|
||||
*/
|
||||
|
||||
|
|
@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double HybridBayesNet::error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
GaussianBayesNet gbn = choose(discreteValues);
|
||||
return gbn.error(continuousValues);
|
||||
double HybridBayesNet::error(const HybridValues &values) const {
|
||||
GaussianBayesNet gbn = choose(values.discrete());
|
||||
return gbn.error(values.continuous());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
|
|
@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @brief 0.5 * sum of squared Mahalanobis distances
|
||||
* for a specific discrete assignment.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @param discreteValues Discrete assignment for a specific mode sequence.
|
||||
* @param values Continuous values and discrete assignment.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const;
|
||||
double error(const HybridValues &values) const;
|
||||
|
||||
/**
|
||||
* @brief Compute conditional error for each discrete assignment,
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ namespace gtsam {
|
|||
* having diamond inheritances, and neutralized the need to change other
|
||||
* components of GTSAM to make hybrid elimination work.
|
||||
*
|
||||
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
|
||||
* A great reference to the type-erasure pattern is Eduardo Madrid's CppCon
|
||||
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
|
||||
*
|
||||
* @ingroup hybrid
|
||||
|
|
@ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional
|
|||
*/
|
||||
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianMixture
|
||||
* @return nullptr if not a mixture
|
||||
* @return GaussianMixture::shared_ptr otherwise
|
||||
*/
|
||||
GaussianMixture::shared_ptr asMixture() {
|
||||
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianConditional
|
||||
* @return nullptr if not a GaussianConditional
|
||||
* @return GaussianConditional::shared_ptr otherwise
|
||||
*/
|
||||
GaussianConditional::shared_ptr asGaussian() {
|
||||
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return conditional as a DiscreteConditional
|
||||
* @return nullptr if not a DiscreteConditional
|
||||
* @return DiscreteConditional::shared_ptr
|
||||
*/
|
||||
DiscreteConditional::shared_ptr asDiscrete() {
|
||||
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
|
@ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional
|
|||
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianMixture
|
||||
* @return nullptr if not a mixture
|
||||
* @return GaussianMixture::shared_ptr otherwise
|
||||
*/
|
||||
GaussianMixture::shared_ptr asMixture() const {
|
||||
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianConditional
|
||||
* @return nullptr if not a GaussianConditional
|
||||
* @return GaussianConditional::shared_ptr otherwise
|
||||
*/
|
||||
GaussianConditional::shared_ptr asGaussian() const {
|
||||
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return conditional as a DiscreteConditional
|
||||
* @return nullptr if not a DiscreteConditional
|
||||
* @return DiscreteConditional::shared_ptr
|
||||
*/
|
||||
DiscreteConditional::shared_ptr asDiscrete() const {
|
||||
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
|
||||
}
|
||||
|
||||
/// Get the type-erased pointer to the inner type
|
||||
boost::shared_ptr<Factor> inner() { return inner_; }
|
||||
|
||||
/// Return the error of the underlying conditional.
|
||||
/// Currently only implemented for Gaussian mixture.
|
||||
double error(const HybridValues& values) const override {
|
||||
if (auto gm = asMixture()) {
|
||||
return gm->error(values);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: only implemented for Gaussian mixture");
|
||||
}
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
|
|
@ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s,
|
|||
inner_->print("\n", formatter);
|
||||
};
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridDiscreteFactor::error(const HybridValues &values) const {
|
||||
return -log((*inner_)(values.discrete()));
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -24,10 +24,12 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows
|
||||
* us to hide the implementation of DiscreteFactor and thus avoid diamond
|
||||
* inheritance.
|
||||
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which
|
||||
* allows us to hide the implementation of DiscreteFactor and thus avoid
|
||||
* diamond inheritance.
|
||||
*
|
||||
* @ingroup hybrid
|
||||
*/
|
||||
|
|
@ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
|||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Return pointer to the internal discrete factor
|
||||
DiscreteFactor::shared_ptr inner() const { return inner_; }
|
||||
|
||||
/// Return the error of the underlying Discrete Factor.
|
||||
double error(const HybridValues &values) const override;
|
||||
/// @}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@
|
|||
#include <string>
|
||||
namespace gtsam {
|
||||
|
||||
class HybridValues;
|
||||
|
||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys);
|
||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
||||
|
|
@ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||
* values and a discrete assignment.
|
||||
*
|
||||
* @param values Continuous values and discrete assignment.
|
||||
* @return double
|
||||
*/
|
||||
virtual double error(const HybridValues &values) const = 0;
|
||||
|
||||
/// True if this is a factor of discrete variables only.
|
||||
bool isDiscrete() const { return isDiscrete_; }
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/linear/HessianFactor.h>
|
||||
#include <gtsam/linear/JacobianFactor.h>
|
||||
|
||||
|
|
@ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s,
|
|||
inner_->print("\n", formatter);
|
||||
};
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridGaussianFactor::error(const HybridValues &values) const {
|
||||
return inner_->error(values.continuous());
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ namespace gtsam {
|
|||
// Forward declarations
|
||||
class JacobianFactor;
|
||||
class HessianFactor;
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
|
||||
|
|
@ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Return pointer to the internal discrete factor
|
||||
GaussianFactor::shared_ptr inner() const { return inner_; }
|
||||
|
||||
/// Return the error of the underlying Discrete Factor.
|
||||
double error(const HybridValues &values) const override;
|
||||
/// @}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -55,13 +55,14 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||
|
||||
/* ************************************************************************ */
|
||||
static GaussianMixtureFactor::Sum &addGaussian(
|
||||
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
||||
using Y = GaussianFactorGraph;
|
||||
// If the decision tree is not intiialized, then intialize it.
|
||||
// If the decision tree is not initialized, then initialize it.
|
||||
if (sum.empty()) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor);
|
||||
|
|
@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals(
|
|||
|
||||
for (auto &f : factors) {
|
||||
if (f->isHybrid()) {
|
||||
if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
sum = cgmf->add(sum);
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
sum = gm->add(sum);
|
||||
}
|
||||
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
sum = gm->asMixture()->add(sum);
|
||||
|
|
@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
const KeySet &continuousSeparator,
|
||||
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
||||
// NOTE: since we use the special JunctionTree,
|
||||
// only possiblity is continuous conditioned on discrete.
|
||||
// only possibility is continuous conditioned on discrete.
|
||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||
discreteSeparatorSet.end());
|
||||
|
||||
|
|
@ -204,16 +206,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
};
|
||||
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
||||
|
||||
using EliminationPair = GaussianFactorGraph::EliminationResult;
|
||||
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
GaussianMixtureFactor::FactorAndConstant>;
|
||||
|
||||
KeyVector keysOfEliminated; // Not the ordering
|
||||
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
||||
|
||||
// This is the elimination method on the leaf nodes
|
||||
auto eliminate = [&](const GaussianFactorGraph &graph)
|
||||
-> GaussianFactorGraph::EliminationResult {
|
||||
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
|
||||
if (graph.empty()) {
|
||||
return {nullptr, nullptr};
|
||||
return {nullptr, {nullptr, 0.0}};
|
||||
}
|
||||
|
||||
#ifdef HYBRID_TIMING
|
||||
|
|
@ -222,18 +224,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
|
||||
std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
boost::shared_ptr<GaussianFactor>>
|
||||
result = EliminatePreferCholesky(graph, frontalKeys);
|
||||
conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
|
||||
|
||||
// Initialize the keysOfEliminated to be the keys of the
|
||||
// eliminated GaussianConditional
|
||||
keysOfEliminated = result.first->keys();
|
||||
keysOfSeparator = result.second->keys();
|
||||
keysOfEliminated = conditional_factor.first->keys();
|
||||
keysOfSeparator = conditional_factor.second->keys();
|
||||
|
||||
#ifdef HYBRID_TIMING
|
||||
gttoc_(hybrid_eliminate);
|
||||
#endif
|
||||
|
||||
return result;
|
||||
return {conditional_factor.first, {conditional_factor.second, 0.0}};
|
||||
};
|
||||
|
||||
// Perform elimination!
|
||||
|
|
@ -246,8 +248,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
|
||||
// Separate out decision tree into conditionals and remaining factors.
|
||||
auto pair = unzip(eliminationResults);
|
||||
|
||||
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
|
||||
const auto &separatorFactors = pair.second;
|
||||
|
||||
// Create the GaussianMixture from the conditionals
|
||||
auto conditional = boost::make_shared<GaussianMixture>(
|
||||
|
|
@ -257,13 +258,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
// DiscreteFactor, with the error for each discrete choice.
|
||||
if (keysOfSeparator.empty()) {
|
||||
VectorValues empty_values;
|
||||
auto factorProb = [&](const GaussianFactor::shared_ptr &factor) {
|
||||
auto factorProb =
|
||||
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
|
||||
GaussianFactor::shared_ptr factor = factor_z.factor;
|
||||
if (!factor) {
|
||||
return 0.0; // If nullptr, return 0.0 probability
|
||||
} else {
|
||||
// This is the probability q(μ) at the MLE point.
|
||||
double error =
|
||||
0.5 * std::abs(factor->augmentedInformation().determinant());
|
||||
0.5 * std::abs(factor->augmentedInformation().determinant()) +
|
||||
factor_z.constant;
|
||||
return std::exp(-error);
|
||||
}
|
||||
};
|
||||
|
|
@ -452,6 +456,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
|
||||
// Iterate over each factor.
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
AlgebraicDecisionTree<Key> factor_error;
|
||||
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
|
|
@ -491,38 +496,17 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridGaussianFactorGraph::error(
|
||||
const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||
double error = 0.0;
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
auto factor = factors_.at(idx);
|
||||
|
||||
if (factor->isHybrid()) {
|
||||
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
error += c->asMixture()->error(continuousValues, discreteValues);
|
||||
}
|
||||
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
||||
error += f->error(continuousValues, discreteValues);
|
||||
}
|
||||
|
||||
} else if (factor->isContinuous()) {
|
||||
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||
error += f->inner()->error(continuousValues);
|
||||
}
|
||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
error += cg->asGaussian()->error(continuousValues);
|
||||
}
|
||||
}
|
||||
for (auto &factor : factors_) {
|
||||
error += factor->error(values);
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridGaussianFactorGraph::probPrime(
|
||||
const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
double error = this->error(continuousValues, discreteValues);
|
||||
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
||||
double error = this->error(values);
|
||||
// NOTE: The 0.5 term is handled by each factor
|
||||
return std::exp(-error);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
/**
|
||||
* @file HybridGaussianFactorGraph.h
|
||||
* @brief Linearized Hybrid factor graph that uses type erasure
|
||||
* @author Fan Jiang, Varun Agrawal
|
||||
* @author Fan Jiang, Varun Agrawal, Frank Dellaert
|
||||
* @date Mar 11, 2022
|
||||
*/
|
||||
|
||||
|
|
@ -38,6 +38,7 @@ class HybridBayesTree;
|
|||
class HybridJunctionTree;
|
||||
class DecisionTreeFactor;
|
||||
class JacobianFactor;
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* @brief Main elimination function for HybridGaussianFactorGraph.
|
||||
|
|
@ -186,14 +187,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
* @brief Compute error given a continuous vector values
|
||||
* and a discrete assignment.
|
||||
*
|
||||
* @param continuousValues The continuous VectorValues
|
||||
* for computing the error.
|
||||
* @param discreteValues The specific discrete assignment
|
||||
* whose error we wish to compute.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues& continuousValues,
|
||||
const DiscreteValues& discreteValues) const;
|
||||
double error(const HybridValues& values) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||
|
|
@ -210,13 +206,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
* @brief Compute the unnormalized posterior probability for a continuous
|
||||
* vector values given a specific assignment.
|
||||
*
|
||||
* @param continuousValues The vector values for which to compute the
|
||||
* posterior probability.
|
||||
* @param discreteValues The specific assignment to use for the computation.
|
||||
* @return double
|
||||
*/
|
||||
double probPrime(const VectorValues& continuousValues,
|
||||
const DiscreteValues& discreteValues) const;
|
||||
double probPrime(const HybridValues& values) const;
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
|
|
|
|||
|
|
@ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor {
|
|||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
NonlinearFactor::shared_ptr inner() const { return inner_; }
|
||||
|
||||
/// Error for HybridValues is not provided for nonlinear factor.
|
||||
double error(const HybridValues &values) const override {
|
||||
throw std::runtime_error(
|
||||
"HybridNonlinearFactor::error(HybridValues) not implemented.");
|
||||
}
|
||||
|
||||
/// Linearize to a HybridGaussianFactor at the linearization point `c`.
|
||||
boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const {
|
||||
return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c));
|
||||
}
|
||||
|
||||
/// @}
|
||||
};
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor {
|
|||
factor, continuousValues);
|
||||
}
|
||||
|
||||
/// Error for HybridValues is not provided for nonlinear hybrid factor.
|
||||
double error(const HybridValues &values) const override {
|
||||
throw std::runtime_error(
|
||||
"MixtureFactor::error(HybridValues) not implemented.");
|
||||
}
|
||||
|
||||
size_t dim() const {
|
||||
// TODO(Varun)
|
||||
throw std::runtime_error("MixtureFactor::dim not implemented.");
|
||||
|
|
|
|||
|
|
@ -183,10 +183,8 @@ class HybridGaussianFactorGraph {
|
|||
bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const;
|
||||
|
||||
// evaluation
|
||||
double error(const gtsam::VectorValues& continuousValues,
|
||||
const gtsam::DiscreteValues& discreteValues) const;
|
||||
double probPrime(const gtsam::VectorValues& continuousValues,
|
||||
const gtsam::DiscreteValues& discreteValues) const;
|
||||
double error(const gtsam::HybridValues& values) const;
|
||||
double probPrime(const gtsam::HybridValues& values) const;
|
||||
|
||||
gtsam::HybridBayesNet* eliminateSequential();
|
||||
gtsam::HybridBayesNet* eliminateSequential(
|
||||
|
|
|
|||
|
|
@ -128,9 +128,9 @@ TEST(GaussianMixture, Error) {
|
|||
// Regression for non-tree version.
|
||||
DiscreteValues assignment;
|
||||
assignment[M(1)] = 0;
|
||||
EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8);
|
||||
EXPECT_DOUBLES_EQUAL(0.5, mixture.error({values, assignment}), 1e-8);
|
||||
assignment[M(1)] = 1;
|
||||
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment),
|
||||
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error({values, assignment}),
|
||||
1e-8);
|
||||
}
|
||||
|
||||
|
|
@ -179,7 +179,9 @@ TEST(GaussianMixture, Likelihood) {
|
|||
const GaussianMixtureFactor::Factors factors(
|
||||
gm.conditionals(),
|
||||
[measurements](const GaussianConditional::shared_ptr& conditional) {
|
||||
return conditional->likelihood(measurements);
|
||||
return GaussianMixtureFactor::FactorAndConstant{
|
||||
conditional->likelihood(measurements),
|
||||
conditional->logNormalizationConstant()};
|
||||
});
|
||||
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
|
||||
EXPECT(assert_equal(*factor, expected));
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
|
|
@ -188,7 +189,8 @@ TEST(GaussianMixtureFactor, Error) {
|
|||
DiscreteValues discreteValues;
|
||||
discreteValues[m1.first] = 1;
|
||||
EXPECT_DOUBLES_EQUAL(
|
||||
4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9);
|
||||
4.0, mixtureFactor.error({continuousValues, discreteValues}),
|
||||
1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
|
|
@ -188,14 +188,14 @@ TEST(HybridBayesNet, Optimize) {
|
|||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
//TODO(Varun) The expectedAssignment should be 111, not 101
|
||||
// TODO(Varun) The expectedAssignment should be 111, not 101
|
||||
DiscreteValues expectedAssignment;
|
||||
expectedAssignment[M(0)] = 1;
|
||||
expectedAssignment[M(1)] = 0;
|
||||
expectedAssignment[M(2)] = 1;
|
||||
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
||||
|
||||
//TODO(Varun) This should be all -Vector1::Ones()
|
||||
// TODO(Varun) This should be all -Vector1::Ones()
|
||||
VectorValues expectedValues;
|
||||
expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
|
||||
expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
|
||||
|
|
@ -243,8 +243,8 @@ TEST(HybridBayesNet, Error) {
|
|||
double total_error = 0;
|
||||
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
|
||||
if (hybridBayesNet->at(idx)->isHybrid()) {
|
||||
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(),
|
||||
discrete_values);
|
||||
double error = hybridBayesNet->atMixture(idx)->error(
|
||||
{delta.continuous(), discrete_values});
|
||||
total_error += error;
|
||||
} else if (hybridBayesNet->at(idx)->isContinuous()) {
|
||||
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
|
||||
|
|
@ -253,7 +253,7 @@ TEST(HybridBayesNet, Error) {
|
|||
}
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(
|
||||
total_error, hybridBayesNet->error(delta.continuous(), discrete_values),
|
||||
total_error, hybridBayesNet->error({delta.continuous(), discrete_values}),
|
||||
1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
|
||||
|
|
|
|||
|
|
@ -273,7 +273,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
|
|||
continue;
|
||||
}
|
||||
|
||||
double error = graph.error(delta, assignment);
|
||||
double error = graph.error({delta, assignment});
|
||||
probPrimes.push_back(exp(-error));
|
||||
}
|
||||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
||||
|
|
@ -487,8 +487,8 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
|
|||
const HybridValues& sample) -> double {
|
||||
const DiscreteValues assignment = sample.discrete();
|
||||
// Compute in log form for numerical stability
|
||||
double log_ratio = bayesNet->error(sample.continuous(), assignment) -
|
||||
factorGraph->error(sample.continuous(), assignment);
|
||||
double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
|
||||
factorGraph->error({sample.continuous(), assignment});
|
||||
double ratio = exp(-log_ratio);
|
||||
return ratio;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -575,18 +575,14 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
|||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
graph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
double error = graph.error(delta.continuous(), delta.discrete());
|
||||
|
||||
double expected_error = 0.490243199;
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error, 1e-9));
|
||||
|
||||
double probs = exp(-error);
|
||||
double expected_probs = graph.probPrime(delta.continuous(), delta.discrete());
|
||||
const HybridValues delta = hybridBayesNet->optimize();
|
||||
const double error = graph.error(delta);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_probs, probs, 1e-7));
|
||||
EXPECT(assert_equal(1.58886, error, 1e-5));
|
||||
|
||||
// Real test:
|
||||
EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
|
|
|||
|
|
@ -168,26 +168,30 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
double GaussianConditional::logDeterminant() const {
|
||||
double logDet;
|
||||
if (this->get_model()) {
|
||||
Vector diag = this->R().diagonal();
|
||||
this->get_model()->whitenInPlace(diag);
|
||||
logDet = diag.unaryExpr([](double x) { return log(x); }).sum();
|
||||
if (get_model()) {
|
||||
Vector diag = R().diagonal();
|
||||
get_model()->whitenInPlace(diag);
|
||||
return diag.unaryExpr([](double x) { return log(x); }).sum();
|
||||
} else {
|
||||
logDet =
|
||||
this->R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
|
||||
return R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
|
||||
}
|
||||
return logDet;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// density = exp(-error(x)) / sqrt((2*pi)^n*det(Sigma))
|
||||
// log = -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
||||
double GaussianConditional::logDensity(const VectorValues& x) const {
|
||||
// normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||
// log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
||||
double GaussianConditional::logNormalizationConstant() const {
|
||||
constexpr double log2pi = 1.8378770664093454835606594728112;
|
||||
size_t n = d().size();
|
||||
// log det(Sigma)) = - 2.0 * logDeterminant()
|
||||
return - error(x) - 0.5 * n * log2pi + logDeterminant();
|
||||
return - 0.5 * n * log2pi + logDeterminant();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// density = k exp(-error(x))
|
||||
// log = log(k) -error(x) - 0.5 * n*log(2*pi)
|
||||
double GaussianConditional::logDensity(const VectorValues& x) const {
|
||||
return logNormalizationConstant() - error(x);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ namespace gtsam {
|
|||
*
|
||||
* @return double
|
||||
*/
|
||||
double determinant() const { return exp(this->logDeterminant()); }
|
||||
inline double determinant() const { return exp(logDeterminant()); }
|
||||
|
||||
/**
|
||||
* @brief Compute the log determinant of the R matrix.
|
||||
|
|
@ -184,6 +184,19 @@ namespace gtsam {
|
|||
*/
|
||||
double logDeterminant() const;
|
||||
|
||||
/**
|
||||
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||
* log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
||||
*/
|
||||
double logNormalizationConstant() const;
|
||||
|
||||
/**
|
||||
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||
*/
|
||||
inline double normalizationConstant() const {
|
||||
return exp(logNormalizationConstant());
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves a conditional Gaussian and writes the solution into the entries of
|
||||
* \c x for each frontal variable of the conditional. The parents are
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ All Rights Reserved
|
|||
See LICENSE for the license information
|
||||
|
||||
Unit tests for Hybrid Factor Graphs.
|
||||
Author: Fan Jiang
|
||||
Author: Fan Jiang, Varun Agrawal, Frank Dellaert
|
||||
"""
|
||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||
|
||||
|
|
@ -18,13 +18,14 @@ from gtsam.utils.test_case import GtsamTestCase
|
|||
|
||||
import gtsam
|
||||
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
||||
GaussianMixture, GaussianMixtureFactor,
|
||||
GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues,
|
||||
HybridGaussianFactorGraph, JacobianFactor, Ordering,
|
||||
noiseModel)
|
||||
|
||||
|
||||
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||
"""Unit tests for HybridGaussianFactorGraph."""
|
||||
|
||||
def test_create(self):
|
||||
"""Test construction of hybrid factor graph."""
|
||||
model = noiseModel.Unit.Create(3)
|
||||
|
|
@ -81,13 +82,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
self.assertEqual(hv.atDiscrete(C(0)), 1)
|
||||
|
||||
@staticmethod
|
||||
def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet:
|
||||
def tiny(num_measurements: int = 1) -> HybridBayesNet:
|
||||
"""
|
||||
Create a tiny two variable hybrid model which represents
|
||||
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
|
||||
"""
|
||||
# Create hybrid Bayes net.
|
||||
bayesNet = gtsam.HybridBayesNet()
|
||||
bayesNet = HybridBayesNet()
|
||||
|
||||
# Create mode key: 0 is low-noise, 1 is high-noise.
|
||||
mode = (M(0), 2)
|
||||
|
|
@ -113,35 +114,76 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
bayesNet.addGaussian(prior_on_x0)
|
||||
|
||||
# Add prior on mode.
|
||||
bayesNet.emplaceDiscrete(mode, "1/1")
|
||||
bayesNet.emplaceDiscrete(mode, "4/6")
|
||||
|
||||
return bayesNet
|
||||
|
||||
@staticmethod
|
||||
def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues):
|
||||
"""Create a factor graph from the Bayes net with sampled measurements.
|
||||
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
|
||||
and thus represents the same joint probability as the Bayes net.
|
||||
"""
|
||||
fg = HybridGaussianFactorGraph()
|
||||
num_measurements = bayesNet.size() - 2
|
||||
for i in range(num_measurements):
|
||||
conditional = bayesNet.atMixture(i)
|
||||
measurement = gtsam.VectorValues()
|
||||
measurement.insert(Z(i), sample.at(Z(i)))
|
||||
factor = conditional.likelihood(measurement)
|
||||
fg.push_back(factor)
|
||||
fg.push_back(bayesNet.atGaussian(num_measurements))
|
||||
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
|
||||
return fg
|
||||
|
||||
@classmethod
|
||||
def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000):
|
||||
"""Do importance sampling to get an estimate of the discrete marginal P(mode)."""
|
||||
# Use prior on x0, mode as proposal density.
|
||||
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode)
|
||||
|
||||
# Allocate space for marginals.
|
||||
marginals = np.zeros((2,))
|
||||
|
||||
# Do importance sampling.
|
||||
num_measurements = bayesNet.size() - 2
|
||||
for s in range(N):
|
||||
proposed = prior.sample()
|
||||
for i in range(num_measurements):
|
||||
z_i = sample.at(Z(i))
|
||||
proposed.insert(Z(i), z_i)
|
||||
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
|
||||
marginals[proposed.atDiscrete(M(0))] += weight
|
||||
|
||||
# print marginals:
|
||||
marginals /= marginals.sum()
|
||||
return marginals
|
||||
|
||||
def test_tiny(self):
|
||||
"""Test a tiny two variable hybrid model."""
|
||||
bayesNet = self.tiny()
|
||||
sample = bayesNet.sample()
|
||||
# print(sample)
|
||||
|
||||
# Create a factor graph from the Bayes net with sampled measurements.
|
||||
fg = HybridGaussianFactorGraph()
|
||||
conditional = bayesNet.atMixture(0)
|
||||
measurement = gtsam.VectorValues()
|
||||
measurement.insert(Z(0), sample.at(Z(0)))
|
||||
factor = conditional.likelihood(measurement)
|
||||
fg.push_back(factor)
|
||||
fg.push_back(bayesNet.atGaussian(1))
|
||||
fg.push_back(bayesNet.atDiscrete(2))
|
||||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(bayesNet, sample)
|
||||
# print(f"True mode: {sample.atDiscrete(M(0))}")
|
||||
# print(f"P(mode=0; z0) = {marginals[0]}")
|
||||
# print(f"P(mode=1; z0) = {marginals[1]}")
|
||||
|
||||
# Check that the estimate is close to the true value.
|
||||
self.assertAlmostEqual(marginals[0], 0.4, delta=0.1)
|
||||
self.assertAlmostEqual(marginals[1], 0.6, delta=0.1)
|
||||
|
||||
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||
self.assertEqual(fg.size(), 3)
|
||||
|
||||
@staticmethod
|
||||
def calculate_ratio(bayesNet, fg, sample):
|
||||
def calculate_ratio(bayesNet: HybridBayesNet,
|
||||
fg: HybridGaussianFactorGraph,
|
||||
sample: HybridValues):
|
||||
"""Calculate ratio between Bayes net probability and the factor graph."""
|
||||
continuous = gtsam.VectorValues()
|
||||
continuous.insert(X(0), sample.at(X(0)))
|
||||
return bayesNet.evaluate(sample) / fg.probPrime(
|
||||
continuous, sample.discrete())
|
||||
return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0
|
||||
|
||||
def test_ratio(self):
|
||||
"""
|
||||
|
|
@ -153,23 +195,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
|
||||
bayesNet = self.tiny(num_measurements=2)
|
||||
# Sample from the Bayes net.
|
||||
sample: gtsam.HybridValues = bayesNet.sample()
|
||||
sample: HybridValues = bayesNet.sample()
|
||||
# print(sample)
|
||||
|
||||
# Create a factor graph from the Bayes net with sampled measurements.
|
||||
# The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)`
|
||||
# and thus represents the same joint probability as the Bayes net.
|
||||
fg = HybridGaussianFactorGraph()
|
||||
for i in range(2):
|
||||
conditional = bayesNet.atMixture(i)
|
||||
measurement = gtsam.VectorValues()
|
||||
measurement.insert(Z(i), sample.at(Z(i)))
|
||||
factor = conditional.likelihood(measurement)
|
||||
fg.push_back(factor)
|
||||
fg.push_back(bayesNet.atGaussian(2))
|
||||
fg.push_back(bayesNet.atDiscrete(3))
|
||||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(bayesNet, sample)
|
||||
# print(f"True mode: {sample.atDiscrete(M(0))}")
|
||||
# print(f"P(mode=0; z0, z1) = {marginals[0]}")
|
||||
# print(f"P(mode=1; z0, z1) = {marginals[1]}")
|
||||
|
||||
# print(fg)
|
||||
# Check marginals based on sampled mode.
|
||||
if sample.atDiscrete(M(0)) == 0:
|
||||
self.assertGreater(marginals[0], marginals[1])
|
||||
else:
|
||||
self.assertGreater(marginals[1], marginals[0])
|
||||
|
||||
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||
self.assertEqual(fg.size(), 4)
|
||||
|
||||
# Calculate ratio between Bayes net probability and the factor graph:
|
||||
|
|
@ -185,10 +226,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
for i in range(10):
|
||||
other = bayesNet.sample()
|
||||
other.update(measurements)
|
||||
# print(other)
|
||||
# ratio = self.calculate_ratio(bayesNet, fg, other)
|
||||
ratio = self.calculate_ratio(bayesNet, fg, other)
|
||||
# print(f"Ratio: {ratio}\n")
|
||||
# self.assertAlmostEqual(ratio, expected_ratio)
|
||||
if (ratio > 0):
|
||||
self.assertAlmostEqual(ratio, expected_ratio)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue