Switch to using HybridValues
parent
b972be0b8f
commit
9cf3e5c26a
|
|
@ -22,6 +22,7 @@
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/hybrid/GaussianMixture.h>
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
|
|
@ -159,9 +160,9 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
||||||
std::set<DiscreteKey> s;
|
std::set<DiscreteKey> s;
|
||||||
s.insert(dkeys.begin(), dkeys.end());
|
s.insert(discreteKeys.begin(), discreteKeys.end());
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -186,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
const GaussianConditional::shared_ptr &conditional)
|
||||||
-> GaussianConditional::shared_ptr {
|
-> GaussianConditional::shared_ptr {
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
DiscreteValues values(choices);
|
const DiscreteValues values(choices);
|
||||||
|
|
||||||
// Case where the gaussian mixture has the same
|
// Case where the gaussian mixture has the same
|
||||||
// discrete keys as the decision tree.
|
// discrete keys as the decision tree.
|
||||||
|
|
@ -256,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixture::error(const VectorValues &continuousValues,
|
double GaussianMixture::error(const HybridValues &values) const {
|
||||||
const DiscreteValues &discreteValues) const {
|
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
// Directly index to get the conditional, no need to build the whole tree.
|
||||||
auto conditional = conditionals_(discreteValues);
|
auto conditional = conditionals_(values.discrete());
|
||||||
return conditional->error(continuousValues);
|
return conditional->error(values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class GaussianMixtureFactor;
|
class GaussianMixtureFactor;
|
||||||
|
class HybridValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
|
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
|
||||||
|
|
@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Defaut constructor, mainly for serialization.
|
/// Default constructor, mainly for serialization.
|
||||||
GaussianMixture() = default;
|
GaussianMixture() = default;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/// @name Standard API
|
/// @name Standard API
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// @brief Return the conditional Gaussian for the given discrete assignment.
|
||||||
GaussianConditional::shared_ptr operator()(
|
GaussianConditional::shared_ptr operator()(
|
||||||
const DiscreteValues &discreteValues) const;
|
const DiscreteValues &discreteValues) const;
|
||||||
|
|
||||||
|
|
@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||||
* values and a discrete assignment.
|
* values and a discrete assignment.
|
||||||
*
|
*
|
||||||
* @param continuousValues Continuous values at which to compute the error.
|
* @param values Continuous values and discrete assignment.
|
||||||
* @param discreteValues The discrete assignment for a specific mode sequence.
|
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double error(const VectorValues &continuousValues,
|
double error(const HybridValues &values) const;
|
||||||
const DiscreteValues &discreteValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||||
|
|
@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Return the DiscreteKey vector as a set.
|
/// Return the DiscreteKey vector as a set.
|
||||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);
|
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template <>
|
template <>
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
/* ----------------------------------------------------------------------------
|
/* ----------------------------------------------------------------------------
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
|
||||||
* Atlanta, Georgia 30332-0415
|
* Atlanta, Georgia 30332-0415
|
||||||
* All Rights Reserved
|
* All Rights Reserved
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
@ -12,6 +12,7 @@
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
* @author Varun Agrawal
|
* @author Varun Agrawal
|
||||||
* @author Shangjie Xue
|
* @author Shangjie Xue
|
||||||
|
* @author Frank Dellaert
|
||||||
* @date January 2022
|
* @date January 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double HybridBayesNet::error(const VectorValues &continuousValues,
|
double HybridBayesNet::error(const HybridValues &values) const {
|
||||||
const DiscreteValues &discreteValues) const {
|
GaussianBayesNet gbn = choose(values.discrete());
|
||||||
GaussianBayesNet gbn = choose(discreteValues);
|
return gbn.error(values.continuous());
|
||||||
return gbn.error(continuousValues);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @brief 0.5 * sum of squared Mahalanobis distances
|
* @brief 0.5 * sum of squared Mahalanobis distances
|
||||||
* for a specific discrete assignment.
|
* for a specific discrete assignment.
|
||||||
*
|
*
|
||||||
* @param continuousValues Continuous values at which to compute the error.
|
* @param values Continuous values and discrete assignment.
|
||||||
* @param discreteValues Discrete assignment for a specific mode sequence.
|
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double error(const VectorValues &continuousValues,
|
double error(const HybridValues &values) const;
|
||||||
const DiscreteValues &discreteValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute conditional error for each discrete assignment,
|
* @brief Compute conditional error for each discrete assignment,
|
||||||
|
|
|
||||||
|
|
@ -55,13 +55,14 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static GaussianMixtureFactor::Sum &addGaussian(
|
static GaussianMixtureFactor::Sum &addGaussian(
|
||||||
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
||||||
using Y = GaussianFactorGraph;
|
using Y = GaussianFactorGraph;
|
||||||
// If the decision tree is not intiialized, then intialize it.
|
// If the decision tree is not initialized, then initialize it.
|
||||||
if (sum.empty()) {
|
if (sum.empty()) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor);
|
result.push_back(factor);
|
||||||
|
|
@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals(
|
||||||
|
|
||||||
for (auto &f : factors) {
|
for (auto &f : factors) {
|
||||||
if (f->isHybrid()) {
|
if (f->isHybrid()) {
|
||||||
if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||||
sum = cgmf->add(sum);
|
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
|
sum = gm->add(sum);
|
||||||
}
|
}
|
||||||
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||||
sum = gm->asMixture()->add(sum);
|
sum = gm->asMixture()->add(sum);
|
||||||
|
|
@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
const KeySet &continuousSeparator,
|
const KeySet &continuousSeparator,
|
||||||
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
||||||
// NOTE: since we use the special JunctionTree,
|
// NOTE: since we use the special JunctionTree,
|
||||||
// only possiblity is continuous conditioned on discrete.
|
// only possibility is continuous conditioned on discrete.
|
||||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||||
discreteSeparatorSet.end());
|
discreteSeparatorSet.end());
|
||||||
|
|
||||||
|
|
@ -251,8 +253,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
// Separate out decision tree into conditionals and remaining factors.
|
// Separate out decision tree into conditionals and remaining factors.
|
||||||
auto pair = unzip(eliminationResults);
|
auto pair = unzip(eliminationResults);
|
||||||
|
const auto &separatorFactors = pair.second;
|
||||||
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
|
|
||||||
|
|
||||||
// Create the GaussianMixture from the conditionals
|
// Create the GaussianMixture from the conditionals
|
||||||
auto conditional = boost::make_shared<GaussianMixture>(
|
auto conditional = boost::make_shared<GaussianMixture>(
|
||||||
|
|
@ -460,6 +461,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
|
|
||||||
// Iterate over each factor.
|
// Iterate over each factor.
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
|
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||||
AlgebraicDecisionTree<Key> factor_error;
|
AlgebraicDecisionTree<Key> factor_error;
|
||||||
|
|
||||||
if (factors_.at(idx)->isHybrid()) {
|
if (factors_.at(idx)->isHybrid()) {
|
||||||
|
|
@ -499,27 +501,26 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double HybridGaussianFactorGraph::error(
|
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||||
const VectorValues &continuousValues,
|
|
||||||
const DiscreteValues &discreteValues) const {
|
|
||||||
double error = 0.0;
|
double error = 0.0;
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
|
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||||
auto factor = factors_.at(idx);
|
auto factor = factors_.at(idx);
|
||||||
|
|
||||||
if (factor->isHybrid()) {
|
if (factor->isHybrid()) {
|
||||||
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
error += c->asMixture()->error(continuousValues, discreteValues);
|
error += c->asMixture()->error(values);
|
||||||
}
|
}
|
||||||
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
||||||
error += f->error(continuousValues, discreteValues);
|
error += f->error(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (factor->isContinuous()) {
|
} else if (factor->isContinuous()) {
|
||||||
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||||
error += f->inner()->error(continuousValues);
|
error += f->inner()->error(values.continuous());
|
||||||
}
|
}
|
||||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
error += cg->asGaussian()->error(continuousValues);
|
error += cg->asGaussian()->error(values.continuous());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -527,10 +528,8 @@ double HybridGaussianFactorGraph::error(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double HybridGaussianFactorGraph::probPrime(
|
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
||||||
const VectorValues &continuousValues,
|
double error = this->error(values);
|
||||||
const DiscreteValues &discreteValues) const {
|
|
||||||
double error = this->error(continuousValues, discreteValues);
|
|
||||||
// NOTE: The 0.5 term is handled by each factor
|
// NOTE: The 0.5 term is handled by each factor
|
||||||
return std::exp(-error);
|
return std::exp(-error);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -186,14 +186,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
* @brief Compute error given a continuous vector values
|
* @brief Compute error given a continuous vector values
|
||||||
* and a discrete assignment.
|
* and a discrete assignment.
|
||||||
*
|
*
|
||||||
* @param continuousValues The continuous VectorValues
|
|
||||||
* for computing the error.
|
|
||||||
* @param discreteValues The specific discrete assignment
|
|
||||||
* whose error we wish to compute.
|
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double error(const VectorValues& continuousValues,
|
double error(const HybridValues& values) const;
|
||||||
const DiscreteValues& discreteValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||||
|
|
@ -210,13 +205,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
* @brief Compute the unnormalized posterior probability for a continuous
|
* @brief Compute the unnormalized posterior probability for a continuous
|
||||||
* vector values given a specific assignment.
|
* vector values given a specific assignment.
|
||||||
*
|
*
|
||||||
* @param continuousValues The vector values for which to compute the
|
|
||||||
* posterior probability.
|
|
||||||
* @param discreteValues The specific assignment to use for the computation.
|
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double probPrime(const VectorValues& continuousValues,
|
double probPrime(const HybridValues& values) const;
|
||||||
const DiscreteValues& discreteValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue