Merge pull request #1358 from borglab/hybrid/gaussian-mixture-factor
commit
d0821a57de
|
|
@ -22,6 +22,7 @@
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/hybrid/GaussianMixture.h>
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
|
|
@ -149,17 +150,19 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
||||||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||||
const KeyVector continuousParentKeys = continuousParents();
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
const GaussianMixtureFactor::Factors likelihoods(
|
const GaussianMixtureFactor::Factors likelihoods(
|
||||||
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
|
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
||||||
return conditional->likelihood(frontals);
|
return GaussianMixtureFactor::FactorAndConstant{
|
||||||
|
conditional->likelihood(frontals),
|
||||||
|
conditional->logNormalizationConstant()};
|
||||||
});
|
});
|
||||||
return boost::make_shared<GaussianMixtureFactor>(
|
return boost::make_shared<GaussianMixtureFactor>(
|
||||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
||||||
std::set<DiscreteKey> s;
|
std::set<DiscreteKey> s;
|
||||||
s.insert(dkeys.begin(), dkeys.end());
|
s.insert(discreteKeys.begin(), discreteKeys.end());
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -184,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
const GaussianConditional::shared_ptr &conditional)
|
||||||
-> GaussianConditional::shared_ptr {
|
-> GaussianConditional::shared_ptr {
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
DiscreteValues values(choices);
|
const DiscreteValues values(choices);
|
||||||
|
|
||||||
// Case where the gaussian mixture has the same
|
// Case where the gaussian mixture has the same
|
||||||
// discrete keys as the decision tree.
|
// discrete keys as the decision tree.
|
||||||
|
|
@ -254,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixture::error(const VectorValues &continuousValues,
|
double GaussianMixture::error(const HybridValues &values) const {
|
||||||
const DiscreteValues &discreteValues) const {
|
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
// Directly index to get the conditional, no need to build the whole tree.
|
||||||
auto conditional = conditionals_(discreteValues);
|
auto conditional = conditionals_(values.discrete());
|
||||||
return conditional->error(continuousValues);
|
return conditional->error(values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class GaussianMixtureFactor;
|
class GaussianMixtureFactor;
|
||||||
|
class HybridValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
|
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
|
||||||
|
|
@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Defaut constructor, mainly for serialization.
|
/// Default constructor, mainly for serialization.
|
||||||
GaussianMixture() = default;
|
GaussianMixture() = default;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/// @name Standard API
|
/// @name Standard API
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// @brief Return the conditional Gaussian for the given discrete assignment.
|
||||||
GaussianConditional::shared_ptr operator()(
|
GaussianConditional::shared_ptr operator()(
|
||||||
const DiscreteValues &discreteValues) const;
|
const DiscreteValues &discreteValues) const;
|
||||||
|
|
||||||
|
|
@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||||
* values and a discrete assignment.
|
* values and a discrete assignment.
|
||||||
*
|
*
|
||||||
* @param continuousValues Continuous values at which to compute the error.
|
* @param values Continuous values and discrete assignment.
|
||||||
* @param discreteValues The discrete assignment for a specific mode sequence.
|
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double error(const VectorValues &continuousValues,
|
double error(const HybridValues &values) const override;
|
||||||
const DiscreteValues &discreteValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||||
|
|
@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Return the DiscreteKey vector as a set.
|
/// Return the DiscreteKey vector as a set.
|
||||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);
|
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template <>
|
template <>
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,8 @@
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
#include <gtsam/linear/GaussianFactor.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
@ -29,8 +31,11 @@ namespace gtsam {
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
|
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys,
|
const DiscreteKeys &discreteKeys,
|
||||||
const Factors &factors)
|
const Mixture &factors)
|
||||||
: Base(continuousKeys, discreteKeys), factors_(factors) {}
|
: Base(continuousKeys, discreteKeys),
|
||||||
|
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
|
||||||
|
return FactorAndConstant{gf, 0.0};
|
||||||
|
}) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
|
|
@ -43,10 +48,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
|
|
||||||
// Check the base and the factors:
|
// Check the base and the factors:
|
||||||
return Base::equals(*e, tol) &&
|
return Base::equals(*e, tol) &&
|
||||||
factors_.equals(e->factors_,
|
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
|
||||||
[tol](const GaussianFactor::shared_ptr &f1,
|
const FactorAndConstant &f2) {
|
||||||
const GaussianFactor::shared_ptr &f2) {
|
return f1.factor->equals(*(f2.factor), tol) &&
|
||||||
return f1->equals(*f2, tol);
|
std::abs(f1.constant - f2.constant) < tol;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -60,7 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s,
|
||||||
} else {
|
} else {
|
||||||
factors_.print(
|
factors_.print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"", [&](Key k) { return formatter(k); },
|
||||||
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
|
[&](const FactorAndConstant &gf_z) -> std::string {
|
||||||
|
auto gf = gf_z.factor;
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
std::cout << ":\n";
|
std::cout << ":\n";
|
||||||
if (gf && !gf->empty()) {
|
if (gf && !gf->empty()) {
|
||||||
|
|
@ -75,8 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
|
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
|
||||||
return factors_;
|
return Mixture(factors_, [](const FactorAndConstant &factor_z) {
|
||||||
|
return factor_z.factor;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
@ -95,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||||
const {
|
const {
|
||||||
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
|
auto wrap = [](const FactorAndConstant &factor_z) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor);
|
result.push_back(factor_z.factor);
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
return {factors_, wrap};
|
return {factors_, wrap};
|
||||||
|
|
@ -107,21 +115,18 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
// functor to convert from sharedFactor to double error value.
|
// functor to convert from sharedFactor to double error value.
|
||||||
auto errorFunc =
|
auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
|
||||||
[continuousValues](const GaussianFactor::shared_ptr &factor) {
|
return factor_z.error(continuousValues);
|
||||||
return factor->error(continuousValues);
|
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||||
return errorTree;
|
return errorTree;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixtureFactor::error(
|
double GaussianMixtureFactor::error(const HybridValues &values) const {
|
||||||
const VectorValues &continuousValues,
|
const FactorAndConstant factor_z = factors_(values.discrete());
|
||||||
const DiscreteValues &discreteValues) const {
|
return factor_z.error(values.continuous());
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
|
||||||
auto factor = factors_(discreteValues);
|
|
||||||
return factor->error(continuousValues);
|
|
||||||
}
|
}
|
||||||
|
/* *******************************************************************************/
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -23,17 +23,15 @@
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
|
||||||
#include <gtsam/linear/GaussianFactor.h>
|
#include <gtsam/linear/GaussianFactor.h>
|
||||||
#include <gtsam/linear/VectorValues.h>
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class GaussianFactorGraph;
|
class GaussianFactorGraph;
|
||||||
|
class HybridValues;
|
||||||
// Needed for wrapper.
|
class DiscreteValues;
|
||||||
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
|
class VectorValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Implementation of a discrete conditional mixture factor.
|
* @brief Implementation of a discrete conditional mixture factor.
|
||||||
|
|
@ -53,9 +51,29 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
using shared_ptr = boost::shared_ptr<This>;
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
|
||||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||||
|
using sharedFactor = boost::shared_ptr<GaussianFactor>;
|
||||||
|
|
||||||
/// typedef for Decision Tree of Gaussian Factors
|
/// Gaussian factor and log of normalizing constant.
|
||||||
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
|
struct FactorAndConstant {
|
||||||
|
sharedFactor factor;
|
||||||
|
double constant;
|
||||||
|
|
||||||
|
// Return error with constant correction.
|
||||||
|
double error(const VectorValues &values) const {
|
||||||
|
// Note minus sign: constant is log of normalization constant for probabilities.
|
||||||
|
// Errors is the negative log-likelihood, hence we subtract the constant here.
|
||||||
|
return factor->error(values) - constant;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check pointer equality.
|
||||||
|
bool operator==(const FactorAndConstant &other) const {
|
||||||
|
return factor == other.factor && constant == other.constant;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||||
|
using Factors = DecisionTree<Key, FactorAndConstant>;
|
||||||
|
using Mixture = DecisionTree<Key, sharedFactor>;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Decision tree of Gaussian factors indexed by discrete keys.
|
/// Decision tree of Gaussian factors indexed by discrete keys.
|
||||||
|
|
@ -82,12 +100,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
* @param continuousKeys A vector of keys representing continuous variables.
|
* @param continuousKeys A vector of keys representing continuous variables.
|
||||||
* @param discreteKeys A vector of keys representing discrete variables and
|
* @param discreteKeys A vector of keys representing discrete variables and
|
||||||
* their cardinalities.
|
* their cardinalities.
|
||||||
* @param factors The decision tree of Gaussian Factors stored as the mixture
|
* @param factors The decision tree of Gaussian factors stored as the mixture
|
||||||
* density.
|
* density.
|
||||||
*/
|
*/
|
||||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys,
|
const DiscreteKeys &discreteKeys,
|
||||||
const Factors &factors);
|
const Mixture &factors);
|
||||||
|
|
||||||
|
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||||
|
const DiscreteKeys &discreteKeys,
|
||||||
|
const Factors &factors_and_z)
|
||||||
|
: Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Construct a new GaussianMixtureFactor object using a vector of
|
* @brief Construct a new GaussianMixtureFactor object using a vector of
|
||||||
|
|
@ -99,9 +122,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
*/
|
*/
|
||||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys,
|
const DiscreteKeys &discreteKeys,
|
||||||
const std::vector<GaussianFactor::shared_ptr> &factors)
|
const std::vector<sharedFactor> &factors)
|
||||||
: GaussianMixtureFactor(continuousKeys, discreteKeys,
|
: GaussianMixtureFactor(continuousKeys, discreteKeys,
|
||||||
Factors(discreteKeys, factors)) {}
|
Mixture(discreteKeys, factors)) {}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
@ -113,9 +136,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
const std::string &s = "GaussianMixtureFactor\n",
|
const std::string &s = "GaussianMixtureFactor\n",
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||||
/// @}
|
/// @}
|
||||||
|
/// @name Standard API
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// Getter for the underlying Gaussian Factor Decision Tree.
|
/// Getter for the underlying Gaussian Factor Decision Tree.
|
||||||
const Factors &factors();
|
const Mixture factors() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
||||||
|
|
@ -137,21 +162,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
||||||
* values and a discrete assignment.
|
|
||||||
*
|
|
||||||
* @param continuousValues Continuous values at which to compute the error.
|
|
||||||
* @param discreteValues The discrete assignment for a specific mode sequence.
|
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double error(const VectorValues &continuousValues,
|
double error(const HybridValues &values) const override;
|
||||||
const DiscreteValues &discreteValues) const;
|
|
||||||
|
|
||||||
/// Add MixtureFactor to a Sum, syntactic sugar.
|
/// Add MixtureFactor to a Sum, syntactic sugar.
|
||||||
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
||||||
sum = factor.add(sum);
|
sum = factor.add(sum);
|
||||||
return sum;
|
return sum;
|
||||||
}
|
}
|
||||||
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
/* ----------------------------------------------------------------------------
|
/* ----------------------------------------------------------------------------
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
|
||||||
* Atlanta, Georgia 30332-0415
|
* Atlanta, Georgia 30332-0415
|
||||||
* All Rights Reserved
|
* All Rights Reserved
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
@ -12,6 +12,7 @@
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
* @author Varun Agrawal
|
* @author Varun Agrawal
|
||||||
* @author Shangjie Xue
|
* @author Shangjie Xue
|
||||||
|
* @author Frank Dellaert
|
||||||
* @date January 2022
|
* @date January 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double HybridBayesNet::error(const VectorValues &continuousValues,
|
double HybridBayesNet::error(const HybridValues &values) const {
|
||||||
const DiscreteValues &discreteValues) const {
|
GaussianBayesNet gbn = choose(values.discrete());
|
||||||
GaussianBayesNet gbn = choose(discreteValues);
|
return gbn.error(values.continuous());
|
||||||
return gbn.error(continuousValues);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @brief 0.5 * sum of squared Mahalanobis distances
|
* @brief 0.5 * sum of squared Mahalanobis distances
|
||||||
* for a specific discrete assignment.
|
* for a specific discrete assignment.
|
||||||
*
|
*
|
||||||
* @param continuousValues Continuous values at which to compute the error.
|
* @param values Continuous values and discrete assignment.
|
||||||
* @param discreteValues Discrete assignment for a specific mode sequence.
|
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double error(const VectorValues &continuousValues,
|
double error(const HybridValues &values) const;
|
||||||
const DiscreteValues &discreteValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute conditional error for each discrete assignment,
|
* @brief Compute conditional error for each discrete assignment,
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ namespace gtsam {
|
||||||
* having diamond inheritances, and neutralized the need to change other
|
* having diamond inheritances, and neutralized the need to change other
|
||||||
* components of GTSAM to make hybrid elimination work.
|
* components of GTSAM to make hybrid elimination work.
|
||||||
*
|
*
|
||||||
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
|
* A great reference to the type-erasure pattern is Eduardo Madrid's CppCon
|
||||||
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
|
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
|
||||||
*
|
*
|
||||||
* @ingroup hybrid
|
* @ingroup hybrid
|
||||||
|
|
@ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional
|
||||||
*/
|
*/
|
||||||
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
|
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return HybridConditional as a GaussianMixture
|
|
||||||
* @return nullptr if not a mixture
|
|
||||||
* @return GaussianMixture::shared_ptr otherwise
|
|
||||||
*/
|
|
||||||
GaussianMixture::shared_ptr asMixture() {
|
|
||||||
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return HybridConditional as a GaussianConditional
|
|
||||||
* @return nullptr if not a GaussianConditional
|
|
||||||
* @return GaussianConditional::shared_ptr otherwise
|
|
||||||
*/
|
|
||||||
GaussianConditional::shared_ptr asGaussian() {
|
|
||||||
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return conditional as a DiscreteConditional
|
|
||||||
* @return nullptr if not a DiscreteConditional
|
|
||||||
* @return DiscreteConditional::shared_ptr
|
|
||||||
*/
|
|
||||||
DiscreteConditional::shared_ptr asDiscrete() {
|
|
||||||
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional
|
||||||
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
|
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return HybridConditional as a GaussianMixture
|
||||||
|
* @return nullptr if not a mixture
|
||||||
|
* @return GaussianMixture::shared_ptr otherwise
|
||||||
|
*/
|
||||||
|
GaussianMixture::shared_ptr asMixture() const {
|
||||||
|
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return HybridConditional as a GaussianConditional
|
||||||
|
* @return nullptr if not a GaussianConditional
|
||||||
|
* @return GaussianConditional::shared_ptr otherwise
|
||||||
|
*/
|
||||||
|
GaussianConditional::shared_ptr asGaussian() const {
|
||||||
|
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return conditional as a DiscreteConditional
|
||||||
|
* @return nullptr if not a DiscreteConditional
|
||||||
|
* @return DiscreteConditional::shared_ptr
|
||||||
|
*/
|
||||||
|
DiscreteConditional::shared_ptr asDiscrete() const {
|
||||||
|
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the type-erased pointer to the inner type
|
/// Get the type-erased pointer to the inner type
|
||||||
boost::shared_ptr<Factor> inner() { return inner_; }
|
boost::shared_ptr<Factor> inner() { return inner_; }
|
||||||
|
|
||||||
|
/// Return the error of the underlying conditional.
|
||||||
|
/// Currently only implemented for Gaussian mixture.
|
||||||
|
double error(const HybridValues& values) const override {
|
||||||
|
if (auto gm = asMixture()) {
|
||||||
|
return gm->error(values);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridConditional::error: only implemented for Gaussian mixture");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
friend class boost::serialization::access;
|
friend class boost::serialization::access;
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
|
||||||
|
|
@ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s,
|
||||||
inner_->print("\n", formatter);
|
inner_->print("\n", formatter);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridDiscreteFactor::error(const HybridValues &values) const {
|
||||||
|
return -log((*inner_)(values.discrete()));
|
||||||
|
}
|
||||||
|
/* ************************************************************************ */
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -24,10 +24,12 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
class HybridValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows
|
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which
|
||||||
* us to hide the implementation of DiscreteFactor and thus avoid diamond
|
* allows us to hide the implementation of DiscreteFactor and thus avoid
|
||||||
* inheritance.
|
* diamond inheritance.
|
||||||
*
|
*
|
||||||
* @ingroup hybrid
|
* @ingroup hybrid
|
||||||
*/
|
*/
|
||||||
|
|
@ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// Return pointer to the internal discrete factor
|
/// Return pointer to the internal discrete factor
|
||||||
DiscreteFactor::shared_ptr inner() const { return inner_; }
|
DiscreteFactor::shared_ptr inner() const { return inner_; }
|
||||||
|
|
||||||
|
/// Return the error of the underlying Discrete Factor.
|
||||||
|
double error(const HybridValues &values) const override;
|
||||||
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
class HybridValues;
|
||||||
|
|
||||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys);
|
const DiscreteKeys &discreteKeys);
|
||||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
||||||
|
|
@ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||||
|
* values and a discrete assignment.
|
||||||
|
*
|
||||||
|
* @param values Continuous values and discrete assignment.
|
||||||
|
* @return double
|
||||||
|
*/
|
||||||
|
virtual double error(const HybridValues &values) const = 0;
|
||||||
|
|
||||||
/// True if this is a factor of discrete variables only.
|
/// True if this is a factor of discrete variables only.
|
||||||
bool isDiscrete() const { return isDiscrete_; }
|
bool isDiscrete() const { return isDiscrete_; }
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/linear/HessianFactor.h>
|
#include <gtsam/linear/HessianFactor.h>
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
|
|
||||||
|
|
@ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s,
|
||||||
inner_->print("\n", formatter);
|
inner_->print("\n", formatter);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridGaussianFactor::error(const HybridValues &values) const {
|
||||||
|
return inner_->error(values.continuous());
|
||||||
|
}
|
||||||
|
/* ************************************************************************ */
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ namespace gtsam {
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
class JacobianFactor;
|
class JacobianFactor;
|
||||||
class HessianFactor;
|
class HessianFactor;
|
||||||
|
class HybridValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
|
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
|
||||||
|
|
@ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Return pointer to the internal discrete factor
|
||||||
GaussianFactor::shared_ptr inner() const { return inner_; }
|
GaussianFactor::shared_ptr inner() const { return inner_; }
|
||||||
|
|
||||||
|
/// Return the error of the underlying Discrete Factor.
|
||||||
|
double error(const HybridValues &values) const override;
|
||||||
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
|
|
@ -55,13 +55,14 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static GaussianMixtureFactor::Sum &addGaussian(
|
static GaussianMixtureFactor::Sum &addGaussian(
|
||||||
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
||||||
using Y = GaussianFactorGraph;
|
using Y = GaussianFactorGraph;
|
||||||
// If the decision tree is not intiialized, then intialize it.
|
// If the decision tree is not initialized, then initialize it.
|
||||||
if (sum.empty()) {
|
if (sum.empty()) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor);
|
result.push_back(factor);
|
||||||
|
|
@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals(
|
||||||
|
|
||||||
for (auto &f : factors) {
|
for (auto &f : factors) {
|
||||||
if (f->isHybrid()) {
|
if (f->isHybrid()) {
|
||||||
if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||||
sum = cgmf->add(sum);
|
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
|
sum = gm->add(sum);
|
||||||
}
|
}
|
||||||
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||||
sum = gm->asMixture()->add(sum);
|
sum = gm->asMixture()->add(sum);
|
||||||
|
|
@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
const KeySet &continuousSeparator,
|
const KeySet &continuousSeparator,
|
||||||
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
||||||
// NOTE: since we use the special JunctionTree,
|
// NOTE: since we use the special JunctionTree,
|
||||||
// only possiblity is continuous conditioned on discrete.
|
// only possibility is continuous conditioned on discrete.
|
||||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||||
discreteSeparatorSet.end());
|
discreteSeparatorSet.end());
|
||||||
|
|
||||||
|
|
@ -204,16 +206,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
};
|
};
|
||||||
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
||||||
|
|
||||||
using EliminationPair = GaussianFactorGraph::EliminationResult;
|
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
||||||
|
GaussianMixtureFactor::FactorAndConstant>;
|
||||||
|
|
||||||
KeyVector keysOfEliminated; // Not the ordering
|
KeyVector keysOfEliminated; // Not the ordering
|
||||||
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
auto eliminate = [&](const GaussianFactorGraph &graph)
|
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
|
||||||
-> GaussianFactorGraph::EliminationResult {
|
|
||||||
if (graph.empty()) {
|
if (graph.empty()) {
|
||||||
return {nullptr, nullptr};
|
return {nullptr, {nullptr, 0.0}};
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
|
|
@ -222,18 +224,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
std::pair<boost::shared_ptr<GaussianConditional>,
|
std::pair<boost::shared_ptr<GaussianConditional>,
|
||||||
boost::shared_ptr<GaussianFactor>>
|
boost::shared_ptr<GaussianFactor>>
|
||||||
result = EliminatePreferCholesky(graph, frontalKeys);
|
conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
|
||||||
|
|
||||||
// Initialize the keysOfEliminated to be the keys of the
|
// Initialize the keysOfEliminated to be the keys of the
|
||||||
// eliminated GaussianConditional
|
// eliminated GaussianConditional
|
||||||
keysOfEliminated = result.first->keys();
|
keysOfEliminated = conditional_factor.first->keys();
|
||||||
keysOfSeparator = result.second->keys();
|
keysOfSeparator = conditional_factor.second->keys();
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
gttoc_(hybrid_eliminate);
|
gttoc_(hybrid_eliminate);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return result;
|
return {conditional_factor.first, {conditional_factor.second, 0.0}};
|
||||||
};
|
};
|
||||||
|
|
||||||
// Perform elimination!
|
// Perform elimination!
|
||||||
|
|
@ -246,8 +248,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
// Separate out decision tree into conditionals and remaining factors.
|
// Separate out decision tree into conditionals and remaining factors.
|
||||||
auto pair = unzip(eliminationResults);
|
auto pair = unzip(eliminationResults);
|
||||||
|
const auto &separatorFactors = pair.second;
|
||||||
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
|
|
||||||
|
|
||||||
// Create the GaussianMixture from the conditionals
|
// Create the GaussianMixture from the conditionals
|
||||||
auto conditional = boost::make_shared<GaussianMixture>(
|
auto conditional = boost::make_shared<GaussianMixture>(
|
||||||
|
|
@ -257,13 +258,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// DiscreteFactor, with the error for each discrete choice.
|
// DiscreteFactor, with the error for each discrete choice.
|
||||||
if (keysOfSeparator.empty()) {
|
if (keysOfSeparator.empty()) {
|
||||||
VectorValues empty_values;
|
VectorValues empty_values;
|
||||||
auto factorProb = [&](const GaussianFactor::shared_ptr &factor) {
|
auto factorProb =
|
||||||
|
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
|
||||||
|
GaussianFactor::shared_ptr factor = factor_z.factor;
|
||||||
if (!factor) {
|
if (!factor) {
|
||||||
return 0.0; // If nullptr, return 0.0 probability
|
return 0.0; // If nullptr, return 0.0 probability
|
||||||
} else {
|
} else {
|
||||||
// This is the probability q(μ) at the MLE point.
|
// This is the probability q(μ) at the MLE point.
|
||||||
double error =
|
double error =
|
||||||
0.5 * std::abs(factor->augmentedInformation().determinant());
|
0.5 * std::abs(factor->augmentedInformation().determinant()) +
|
||||||
|
factor_z.constant;
|
||||||
return std::exp(-error);
|
return std::exp(-error);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -452,6 +456,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
|
|
||||||
// Iterate over each factor.
|
// Iterate over each factor.
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
|
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||||
AlgebraicDecisionTree<Key> factor_error;
|
AlgebraicDecisionTree<Key> factor_error;
|
||||||
|
|
||||||
if (factors_.at(idx)->isHybrid()) {
|
if (factors_.at(idx)->isHybrid()) {
|
||||||
|
|
@ -491,38 +496,17 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double HybridGaussianFactorGraph::error(
|
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||||
const VectorValues &continuousValues,
|
|
||||||
const DiscreteValues &discreteValues) const {
|
|
||||||
double error = 0.0;
|
double error = 0.0;
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
for (auto &factor : factors_) {
|
||||||
auto factor = factors_.at(idx);
|
error += factor->error(values);
|
||||||
|
|
||||||
if (factor->isHybrid()) {
|
|
||||||
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
|
||||||
error += c->asMixture()->error(continuousValues, discreteValues);
|
|
||||||
}
|
|
||||||
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
|
||||||
error += f->error(continuousValues, discreteValues);
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if (factor->isContinuous()) {
|
|
||||||
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
|
||||||
error += f->inner()->error(continuousValues);
|
|
||||||
}
|
|
||||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
|
||||||
error += cg->asGaussian()->error(continuousValues);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double HybridGaussianFactorGraph::probPrime(
|
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
||||||
const VectorValues &continuousValues,
|
double error = this->error(values);
|
||||||
const DiscreteValues &discreteValues) const {
|
|
||||||
double error = this->error(continuousValues, discreteValues);
|
|
||||||
// NOTE: The 0.5 term is handled by each factor
|
// NOTE: The 0.5 term is handled by each factor
|
||||||
return std::exp(-error);
|
return std::exp(-error);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
/**
|
/**
|
||||||
* @file HybridGaussianFactorGraph.h
|
* @file HybridGaussianFactorGraph.h
|
||||||
* @brief Linearized Hybrid factor graph that uses type erasure
|
* @brief Linearized Hybrid factor graph that uses type erasure
|
||||||
* @author Fan Jiang, Varun Agrawal
|
* @author Fan Jiang, Varun Agrawal, Frank Dellaert
|
||||||
* @date Mar 11, 2022
|
* @date Mar 11, 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
@ -38,6 +38,7 @@ class HybridBayesTree;
|
||||||
class HybridJunctionTree;
|
class HybridJunctionTree;
|
||||||
class DecisionTreeFactor;
|
class DecisionTreeFactor;
|
||||||
class JacobianFactor;
|
class JacobianFactor;
|
||||||
|
class HybridValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Main elimination function for HybridGaussianFactorGraph.
|
* @brief Main elimination function for HybridGaussianFactorGraph.
|
||||||
|
|
@ -186,14 +187,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
* @brief Compute error given a continuous vector values
|
* @brief Compute error given a continuous vector values
|
||||||
* and a discrete assignment.
|
* and a discrete assignment.
|
||||||
*
|
*
|
||||||
* @param continuousValues The continuous VectorValues
|
|
||||||
* for computing the error.
|
|
||||||
* @param discreteValues The specific discrete assignment
|
|
||||||
* whose error we wish to compute.
|
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double error(const VectorValues& continuousValues,
|
double error(const HybridValues& values) const;
|
||||||
const DiscreteValues& discreteValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||||
|
|
@ -210,13 +206,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
* @brief Compute the unnormalized posterior probability for a continuous
|
* @brief Compute the unnormalized posterior probability for a continuous
|
||||||
* vector values given a specific assignment.
|
* vector values given a specific assignment.
|
||||||
*
|
*
|
||||||
* @param continuousValues The vector values for which to compute the
|
|
||||||
* posterior probability.
|
|
||||||
* @param discreteValues The specific assignment to use for the computation.
|
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double probPrime(const VectorValues& continuousValues,
|
double probPrime(const HybridValues& values) const;
|
||||||
const DiscreteValues& discreteValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||||
|
|
|
||||||
|
|
@ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor {
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
NonlinearFactor::shared_ptr inner() const { return inner_; }
|
NonlinearFactor::shared_ptr inner() const { return inner_; }
|
||||||
|
|
||||||
|
/// Error for HybridValues is not provided for nonlinear factor.
|
||||||
|
double error(const HybridValues &values) const override {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridNonlinearFactor::error(HybridValues) not implemented.");
|
||||||
|
}
|
||||||
|
|
||||||
/// Linearize to a HybridGaussianFactor at the linearization point `c`.
|
/// Linearize to a HybridGaussianFactor at the linearization point `c`.
|
||||||
boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const {
|
boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const {
|
||||||
return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c));
|
return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
};
|
};
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor {
|
||||||
factor, continuousValues);
|
factor, continuousValues);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Error for HybridValues is not provided for nonlinear hybrid factor.
|
||||||
|
double error(const HybridValues &values) const override {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"MixtureFactor::error(HybridValues) not implemented.");
|
||||||
|
}
|
||||||
|
|
||||||
size_t dim() const {
|
size_t dim() const {
|
||||||
// TODO(Varun)
|
// TODO(Varun)
|
||||||
throw std::runtime_error("MixtureFactor::dim not implemented.");
|
throw std::runtime_error("MixtureFactor::dim not implemented.");
|
||||||
|
|
|
||||||
|
|
@ -183,10 +183,8 @@ class HybridGaussianFactorGraph {
|
||||||
bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const;
|
bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const;
|
||||||
|
|
||||||
// evaluation
|
// evaluation
|
||||||
double error(const gtsam::VectorValues& continuousValues,
|
double error(const gtsam::HybridValues& values) const;
|
||||||
const gtsam::DiscreteValues& discreteValues) const;
|
double probPrime(const gtsam::HybridValues& values) const;
|
||||||
double probPrime(const gtsam::VectorValues& continuousValues,
|
|
||||||
const gtsam::DiscreteValues& discreteValues) const;
|
|
||||||
|
|
||||||
gtsam::HybridBayesNet* eliminateSequential();
|
gtsam::HybridBayesNet* eliminateSequential();
|
||||||
gtsam::HybridBayesNet* eliminateSequential(
|
gtsam::HybridBayesNet* eliminateSequential(
|
||||||
|
|
|
||||||
|
|
@ -128,9 +128,9 @@ TEST(GaussianMixture, Error) {
|
||||||
// Regression for non-tree version.
|
// Regression for non-tree version.
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
assignment[M(1)] = 0;
|
assignment[M(1)] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(0.5, mixture.error({values, assignment}), 1e-8);
|
||||||
assignment[M(1)] = 1;
|
assignment[M(1)] = 1;
|
||||||
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment),
|
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error({values, assignment}),
|
||||||
1e-8);
|
1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -179,7 +179,9 @@ TEST(GaussianMixture, Likelihood) {
|
||||||
const GaussianMixtureFactor::Factors factors(
|
const GaussianMixtureFactor::Factors factors(
|
||||||
gm.conditionals(),
|
gm.conditionals(),
|
||||||
[measurements](const GaussianConditional::shared_ptr& conditional) {
|
[measurements](const GaussianConditional::shared_ptr& conditional) {
|
||||||
return conditional->likelihood(measurements);
|
return GaussianMixtureFactor::FactorAndConstant{
|
||||||
|
conditional->likelihood(measurements),
|
||||||
|
conditional->logNormalizationConstant()};
|
||||||
});
|
});
|
||||||
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
|
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
|
||||||
EXPECT(assert_equal(*factor, expected));
|
EXPECT(assert_equal(*factor, expected));
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/hybrid/GaussianMixture.h>
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
|
|
@ -188,7 +189,8 @@ TEST(GaussianMixtureFactor, Error) {
|
||||||
DiscreteValues discreteValues;
|
DiscreteValues discreteValues;
|
||||||
discreteValues[m1.first] = 1;
|
discreteValues[m1.first] = 1;
|
||||||
EXPECT_DOUBLES_EQUAL(
|
EXPECT_DOUBLES_EQUAL(
|
||||||
4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9);
|
4.0, mixtureFactor.error({continuousValues, discreteValues}),
|
||||||
|
1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -188,14 +188,14 @@ TEST(HybridBayesNet, Optimize) {
|
||||||
|
|
||||||
HybridValues delta = hybridBayesNet->optimize();
|
HybridValues delta = hybridBayesNet->optimize();
|
||||||
|
|
||||||
//TODO(Varun) The expectedAssignment should be 111, not 101
|
// TODO(Varun) The expectedAssignment should be 111, not 101
|
||||||
DiscreteValues expectedAssignment;
|
DiscreteValues expectedAssignment;
|
||||||
expectedAssignment[M(0)] = 1;
|
expectedAssignment[M(0)] = 1;
|
||||||
expectedAssignment[M(1)] = 0;
|
expectedAssignment[M(1)] = 0;
|
||||||
expectedAssignment[M(2)] = 1;
|
expectedAssignment[M(2)] = 1;
|
||||||
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
||||||
|
|
||||||
//TODO(Varun) This should be all -Vector1::Ones()
|
// TODO(Varun) This should be all -Vector1::Ones()
|
||||||
VectorValues expectedValues;
|
VectorValues expectedValues;
|
||||||
expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
|
expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
|
||||||
expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
|
expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
|
||||||
|
|
@ -243,8 +243,8 @@ TEST(HybridBayesNet, Error) {
|
||||||
double total_error = 0;
|
double total_error = 0;
|
||||||
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
|
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
|
||||||
if (hybridBayesNet->at(idx)->isHybrid()) {
|
if (hybridBayesNet->at(idx)->isHybrid()) {
|
||||||
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(),
|
double error = hybridBayesNet->atMixture(idx)->error(
|
||||||
discrete_values);
|
{delta.continuous(), discrete_values});
|
||||||
total_error += error;
|
total_error += error;
|
||||||
} else if (hybridBayesNet->at(idx)->isContinuous()) {
|
} else if (hybridBayesNet->at(idx)->isContinuous()) {
|
||||||
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
|
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
|
||||||
|
|
@ -253,7 +253,7 @@ TEST(HybridBayesNet, Error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_DOUBLES_EQUAL(
|
EXPECT_DOUBLES_EQUAL(
|
||||||
total_error, hybridBayesNet->error(delta.continuous(), discrete_values),
|
total_error, hybridBayesNet->error({delta.continuous(), discrete_values}),
|
||||||
1e-9);
|
1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
|
||||||
|
|
|
||||||
|
|
@ -273,7 +273,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
double error = graph.error(delta, assignment);
|
double error = graph.error({delta, assignment});
|
||||||
probPrimes.push_back(exp(-error));
|
probPrimes.push_back(exp(-error));
|
||||||
}
|
}
|
||||||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
||||||
|
|
@ -487,8 +487,8 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
|
||||||
const HybridValues& sample) -> double {
|
const HybridValues& sample) -> double {
|
||||||
const DiscreteValues assignment = sample.discrete();
|
const DiscreteValues assignment = sample.discrete();
|
||||||
// Compute in log form for numerical stability
|
// Compute in log form for numerical stability
|
||||||
double log_ratio = bayesNet->error(sample.continuous(), assignment) -
|
double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
|
||||||
factorGraph->error(sample.continuous(), assignment);
|
factorGraph->error({sample.continuous(), assignment});
|
||||||
double ratio = exp(-log_ratio);
|
double ratio = exp(-log_ratio);
|
||||||
return ratio;
|
return ratio;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -575,18 +575,14 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
||||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||||
graph.eliminateSequential(hybridOrdering);
|
graph.eliminateSequential(hybridOrdering);
|
||||||
|
|
||||||
HybridValues delta = hybridBayesNet->optimize();
|
const HybridValues delta = hybridBayesNet->optimize();
|
||||||
double error = graph.error(delta.continuous(), delta.discrete());
|
const double error = graph.error(delta);
|
||||||
|
|
||||||
double expected_error = 0.490243199;
|
|
||||||
// regression
|
|
||||||
EXPECT(assert_equal(expected_error, error, 1e-9));
|
|
||||||
|
|
||||||
double probs = exp(-error);
|
|
||||||
double expected_probs = graph.probPrime(delta.continuous(), delta.discrete());
|
|
||||||
|
|
||||||
// regression
|
// regression
|
||||||
EXPECT(assert_equal(expected_probs, probs, 1e-7));
|
EXPECT(assert_equal(1.58886, error, 1e-5));
|
||||||
|
|
||||||
|
// Real test:
|
||||||
|
EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -168,26 +168,30 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double GaussianConditional::logDeterminant() const {
|
double GaussianConditional::logDeterminant() const {
|
||||||
double logDet;
|
if (get_model()) {
|
||||||
if (this->get_model()) {
|
Vector diag = R().diagonal();
|
||||||
Vector diag = this->R().diagonal();
|
get_model()->whitenInPlace(diag);
|
||||||
this->get_model()->whitenInPlace(diag);
|
return diag.unaryExpr([](double x) { return log(x); }).sum();
|
||||||
logDet = diag.unaryExpr([](double x) { return log(x); }).sum();
|
|
||||||
} else {
|
} else {
|
||||||
logDet =
|
return R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
|
||||||
this->R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
|
|
||||||
}
|
}
|
||||||
return logDet;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// density = exp(-error(x)) / sqrt((2*pi)^n*det(Sigma))
|
// normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||||
// log = -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
// log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
||||||
double GaussianConditional::logDensity(const VectorValues& x) const {
|
double GaussianConditional::logNormalizationConstant() const {
|
||||||
constexpr double log2pi = 1.8378770664093454835606594728112;
|
constexpr double log2pi = 1.8378770664093454835606594728112;
|
||||||
size_t n = d().size();
|
size_t n = d().size();
|
||||||
// log det(Sigma)) = - 2.0 * logDeterminant()
|
// log det(Sigma)) = - 2.0 * logDeterminant()
|
||||||
return - error(x) - 0.5 * n * log2pi + logDeterminant();
|
return - 0.5 * n * log2pi + logDeterminant();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// density = k exp(-error(x))
|
||||||
|
// log = log(k) -error(x) - 0.5 * n*log(2*pi)
|
||||||
|
double GaussianConditional::logDensity(const VectorValues& x) const {
|
||||||
|
return logNormalizationConstant() - error(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -169,7 +169,7 @@ namespace gtsam {
|
||||||
*
|
*
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double determinant() const { return exp(this->logDeterminant()); }
|
inline double determinant() const { return exp(logDeterminant()); }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the log determinant of the R matrix.
|
* @brief Compute the log determinant of the R matrix.
|
||||||
|
|
@ -184,6 +184,19 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
double logDeterminant() const;
|
double logDeterminant() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||||
|
* log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
||||||
|
*/
|
||||||
|
double logNormalizationConstant() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||||
|
*/
|
||||||
|
inline double normalizationConstant() const {
|
||||||
|
return exp(logNormalizationConstant());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Solves a conditional Gaussian and writes the solution into the entries of
|
* Solves a conditional Gaussian and writes the solution into the entries of
|
||||||
* \c x for each frontal variable of the conditional. The parents are
|
* \c x for each frontal variable of the conditional. The parents are
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ All Rights Reserved
|
||||||
See LICENSE for the license information
|
See LICENSE for the license information
|
||||||
|
|
||||||
Unit tests for Hybrid Factor Graphs.
|
Unit tests for Hybrid Factor Graphs.
|
||||||
Author: Fan Jiang
|
Author: Fan Jiang, Varun Agrawal, Frank Dellaert
|
||||||
"""
|
"""
|
||||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||||
|
|
||||||
|
|
@ -18,13 +18,14 @@ from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
import gtsam
|
import gtsam
|
||||||
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
||||||
GaussianMixture, GaussianMixtureFactor,
|
GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues,
|
||||||
HybridGaussianFactorGraph, JacobianFactor, Ordering,
|
HybridGaussianFactorGraph, JacobianFactor, Ordering,
|
||||||
noiseModel)
|
noiseModel)
|
||||||
|
|
||||||
|
|
||||||
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
"""Unit tests for HybridGaussianFactorGraph."""
|
"""Unit tests for HybridGaussianFactorGraph."""
|
||||||
|
|
||||||
def test_create(self):
|
def test_create(self):
|
||||||
"""Test construction of hybrid factor graph."""
|
"""Test construction of hybrid factor graph."""
|
||||||
model = noiseModel.Unit.Create(3)
|
model = noiseModel.Unit.Create(3)
|
||||||
|
|
@ -81,13 +82,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
self.assertEqual(hv.atDiscrete(C(0)), 1)
|
self.assertEqual(hv.atDiscrete(C(0)), 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet:
|
def tiny(num_measurements: int = 1) -> HybridBayesNet:
|
||||||
"""
|
"""
|
||||||
Create a tiny two variable hybrid model which represents
|
Create a tiny two variable hybrid model which represents
|
||||||
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
|
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
|
||||||
"""
|
"""
|
||||||
# Create hybrid Bayes net.
|
# Create hybrid Bayes net.
|
||||||
bayesNet = gtsam.HybridBayesNet()
|
bayesNet = HybridBayesNet()
|
||||||
|
|
||||||
# Create mode key: 0 is low-noise, 1 is high-noise.
|
# Create mode key: 0 is low-noise, 1 is high-noise.
|
||||||
mode = (M(0), 2)
|
mode = (M(0), 2)
|
||||||
|
|
@ -113,35 +114,76 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
bayesNet.addGaussian(prior_on_x0)
|
bayesNet.addGaussian(prior_on_x0)
|
||||||
|
|
||||||
# Add prior on mode.
|
# Add prior on mode.
|
||||||
bayesNet.emplaceDiscrete(mode, "1/1")
|
bayesNet.emplaceDiscrete(mode, "4/6")
|
||||||
|
|
||||||
return bayesNet
|
return bayesNet
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues):
|
||||||
|
"""Create a factor graph from the Bayes net with sampled measurements.
|
||||||
|
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
|
||||||
|
and thus represents the same joint probability as the Bayes net.
|
||||||
|
"""
|
||||||
|
fg = HybridGaussianFactorGraph()
|
||||||
|
num_measurements = bayesNet.size() - 2
|
||||||
|
for i in range(num_measurements):
|
||||||
|
conditional = bayesNet.atMixture(i)
|
||||||
|
measurement = gtsam.VectorValues()
|
||||||
|
measurement.insert(Z(i), sample.at(Z(i)))
|
||||||
|
factor = conditional.likelihood(measurement)
|
||||||
|
fg.push_back(factor)
|
||||||
|
fg.push_back(bayesNet.atGaussian(num_measurements))
|
||||||
|
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
|
||||||
|
return fg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000):
|
||||||
|
"""Do importance sampling to get an estimate of the discrete marginal P(mode)."""
|
||||||
|
# Use prior on x0, mode as proposal density.
|
||||||
|
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode)
|
||||||
|
|
||||||
|
# Allocate space for marginals.
|
||||||
|
marginals = np.zeros((2,))
|
||||||
|
|
||||||
|
# Do importance sampling.
|
||||||
|
num_measurements = bayesNet.size() - 2
|
||||||
|
for s in range(N):
|
||||||
|
proposed = prior.sample()
|
||||||
|
for i in range(num_measurements):
|
||||||
|
z_i = sample.at(Z(i))
|
||||||
|
proposed.insert(Z(i), z_i)
|
||||||
|
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
|
||||||
|
marginals[proposed.atDiscrete(M(0))] += weight
|
||||||
|
|
||||||
|
# print marginals:
|
||||||
|
marginals /= marginals.sum()
|
||||||
|
return marginals
|
||||||
|
|
||||||
def test_tiny(self):
|
def test_tiny(self):
|
||||||
"""Test a tiny two variable hybrid model."""
|
"""Test a tiny two variable hybrid model."""
|
||||||
bayesNet = self.tiny()
|
bayesNet = self.tiny()
|
||||||
sample = bayesNet.sample()
|
sample = bayesNet.sample()
|
||||||
# print(sample)
|
# print(sample)
|
||||||
|
|
||||||
# Create a factor graph from the Bayes net with sampled measurements.
|
# Estimate marginals using importance sampling.
|
||||||
fg = HybridGaussianFactorGraph()
|
marginals = self.estimate_marginals(bayesNet, sample)
|
||||||
conditional = bayesNet.atMixture(0)
|
# print(f"True mode: {sample.atDiscrete(M(0))}")
|
||||||
measurement = gtsam.VectorValues()
|
# print(f"P(mode=0; z0) = {marginals[0]}")
|
||||||
measurement.insert(Z(0), sample.at(Z(0)))
|
# print(f"P(mode=1; z0) = {marginals[1]}")
|
||||||
factor = conditional.likelihood(measurement)
|
|
||||||
fg.push_back(factor)
|
|
||||||
fg.push_back(bayesNet.atGaussian(1))
|
|
||||||
fg.push_back(bayesNet.atDiscrete(2))
|
|
||||||
|
|
||||||
|
# Check that the estimate is close to the true value.
|
||||||
|
self.assertAlmostEqual(marginals[0], 0.4, delta=0.1)
|
||||||
|
self.assertAlmostEqual(marginals[1], 0.6, delta=0.1)
|
||||||
|
|
||||||
|
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||||
self.assertEqual(fg.size(), 3)
|
self.assertEqual(fg.size(), 3)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def calculate_ratio(bayesNet, fg, sample):
|
def calculate_ratio(bayesNet: HybridBayesNet,
|
||||||
|
fg: HybridGaussianFactorGraph,
|
||||||
|
sample: HybridValues):
|
||||||
"""Calculate ratio between Bayes net probability and the factor graph."""
|
"""Calculate ratio between Bayes net probability and the factor graph."""
|
||||||
continuous = gtsam.VectorValues()
|
return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0
|
||||||
continuous.insert(X(0), sample.at(X(0)))
|
|
||||||
return bayesNet.evaluate(sample) / fg.probPrime(
|
|
||||||
continuous, sample.discrete())
|
|
||||||
|
|
||||||
def test_ratio(self):
|
def test_ratio(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -153,23 +195,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
|
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
|
||||||
bayesNet = self.tiny(num_measurements=2)
|
bayesNet = self.tiny(num_measurements=2)
|
||||||
# Sample from the Bayes net.
|
# Sample from the Bayes net.
|
||||||
sample: gtsam.HybridValues = bayesNet.sample()
|
sample: HybridValues = bayesNet.sample()
|
||||||
# print(sample)
|
# print(sample)
|
||||||
|
|
||||||
# Create a factor graph from the Bayes net with sampled measurements.
|
# Estimate marginals using importance sampling.
|
||||||
# The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)`
|
marginals = self.estimate_marginals(bayesNet, sample)
|
||||||
# and thus represents the same joint probability as the Bayes net.
|
# print(f"True mode: {sample.atDiscrete(M(0))}")
|
||||||
fg = HybridGaussianFactorGraph()
|
# print(f"P(mode=0; z0, z1) = {marginals[0]}")
|
||||||
for i in range(2):
|
# print(f"P(mode=1; z0, z1) = {marginals[1]}")
|
||||||
conditional = bayesNet.atMixture(i)
|
|
||||||
measurement = gtsam.VectorValues()
|
|
||||||
measurement.insert(Z(i), sample.at(Z(i)))
|
|
||||||
factor = conditional.likelihood(measurement)
|
|
||||||
fg.push_back(factor)
|
|
||||||
fg.push_back(bayesNet.atGaussian(2))
|
|
||||||
fg.push_back(bayesNet.atDiscrete(3))
|
|
||||||
|
|
||||||
# print(fg)
|
# Check marginals based on sampled mode.
|
||||||
|
if sample.atDiscrete(M(0)) == 0:
|
||||||
|
self.assertGreater(marginals[0], marginals[1])
|
||||||
|
else:
|
||||||
|
self.assertGreater(marginals[1], marginals[0])
|
||||||
|
|
||||||
|
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||||
self.assertEqual(fg.size(), 4)
|
self.assertEqual(fg.size(), 4)
|
||||||
|
|
||||||
# Calculate ratio between Bayes net probability and the factor graph:
|
# Calculate ratio between Bayes net probability and the factor graph:
|
||||||
|
|
@ -185,10 +226,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
other = bayesNet.sample()
|
other = bayesNet.sample()
|
||||||
other.update(measurements)
|
other.update(measurements)
|
||||||
# print(other)
|
ratio = self.calculate_ratio(bayesNet, fg, other)
|
||||||
# ratio = self.calculate_ratio(bayesNet, fg, other)
|
|
||||||
# print(f"Ratio: {ratio}\n")
|
# print(f"Ratio: {ratio}\n")
|
||||||
# self.assertAlmostEqual(ratio, expected_ratio)
|
if (ratio > 0):
|
||||||
|
self.assertAlmostEqual(ratio, expected_ratio)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue