Switch to using HybridValues

release/4.3a0
Frank Dellaert 2022-12-30 12:10:16 -05:00
parent b972be0b8f
commit 9cf3e5c26a
6 changed files with 38 additions and 50 deletions

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
@ -159,9 +160,9 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
} }
/* ************************************************************************* */ /* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) { std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
std::set<DiscreteKey> s; std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end()); s.insert(discreteKeys.begin(), discreteKeys.end());
return s; return s;
} }
@ -186,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
const GaussianConditional::shared_ptr &conditional) const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr { -> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); const DiscreteValues values(choices);
// Case where the gaussian mixture has the same // Case where the gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
@ -256,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
} }
/* *******************************************************************************/ /* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousValues, double GaussianMixture::error(const HybridValues &values) const {
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree. // Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(discreteValues); auto conditional = conditionals_(values.discrete());
return conditional->error(continuousValues); return conditional->error(values.continuous());
} }
} // namespace gtsam } // namespace gtsam

View File

@ -30,6 +30,7 @@
namespace gtsam { namespace gtsam {
class GaussianMixtureFactor; class GaussianMixtureFactor;
class HybridValues;
/** /**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as * @brief A conditional of gaussian mixtures indexed by discrete variables, as
@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Constructors /// @name Constructors
/// @{ /// @{
/// Defaut constructor, mainly for serialization. /// Default constructor, mainly for serialization.
GaussianMixture() = default; GaussianMixture() = default;
/** /**
@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Standard API /// @name Standard API
/// @{ /// @{
/// @brief Return the conditional Gaussian for the given discrete assignment.
GaussianConditional::shared_ptr operator()( GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const; const DiscreteValues &discreteValues) const;
@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture
* @brief Compute the error of this Gaussian Mixture given the continuous * @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment. * values and a discrete assignment.
* *
* @param continuousValues Continuous values at which to compute the error. * @param values Continuous values and discrete assignment.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double * @return double
*/ */
double error(const VectorValues &continuousValues, double error(const HybridValues &values) const;
const DiscreteValues &discreteValues) const;
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture
}; };
/// Return the DiscreteKey vector as a set. /// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys); std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
// traits // traits
template <> template <>

View File

@ -1,5 +1,5 @@
/* ---------------------------------------------------------------------------- /* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation, * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415 * Atlanta, Georgia 30332-0415
* All Rights Reserved * All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
@ -12,6 +12,7 @@
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
* @author Shangjie Xue * @author Shangjie Xue
* @author Frank Dellaert
* @date January 2022 * @date January 2022
*/ */
@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues, double HybridBayesNet::error(const HybridValues &values) const {
const DiscreteValues &discreteValues) const { GaussianBayesNet gbn = choose(values.discrete());
GaussianBayesNet gbn = choose(discreteValues); return gbn.error(values.continuous());
return gbn.error(continuousValues);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief 0.5 * sum of squared Mahalanobis distances * @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment. * for a specific discrete assignment.
* *
* @param continuousValues Continuous values at which to compute the error. * @param values Continuous values and discrete assignment.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @return double * @return double
*/ */
double error(const VectorValues &continuousValues, double error(const HybridValues &values) const;
const DiscreteValues &discreteValues) const;
/** /**
* @brief Compute conditional error for each discrete assignment, * @brief Compute conditional error for each discrete assignment,

View File

@ -55,13 +55,14 @@
namespace gtsam { namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>; template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian( static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
using Y = GaussianFactorGraph; using Y = GaussianFactorGraph;
// If the decision tree is not intiialized, then intialize it. // If the decision tree is not initialized, then initialize it.
if (sum.empty()) { if (sum.empty()) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals(
for (auto &f : factors) { for (auto &f : factors) {
if (f->isHybrid()) { if (f->isHybrid()) {
if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { // TODO(dellaert): just use a virtual method defined in HybridFactor.
sum = cgmf->add(sum); if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = gm->add(sum);
} }
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) { if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum); sum = gm->asMixture()->add(sum);
@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
const KeySet &continuousSeparator, const KeySet &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) { const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree, // NOTE: since we use the special JunctionTree,
// only possiblity is continuous conditioned on discrete. // only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end()); discreteSeparatorSet.end());
@ -251,8 +253,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Separate out decision tree into conditionals and remaining factors. // Separate out decision tree into conditionals and remaining factors.
auto pair = unzip(eliminationResults); auto pair = unzip(eliminationResults);
const auto &separatorFactors = pair.second;
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
// Create the GaussianMixture from the conditionals // Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>( auto conditional = boost::make_shared<GaussianMixture>(
@ -460,6 +461,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// Iterate over each factor. // Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) { for (size_t idx = 0; idx < size(); idx++) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error; AlgebraicDecisionTree<Key> factor_error;
if (factors_.at(idx)->isHybrid()) { if (factors_.at(idx)->isHybrid()) {
@ -499,27 +501,26 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
} }
/* ************************************************************************ */ /* ************************************************************************ */
double HybridGaussianFactorGraph::error( double HybridGaussianFactorGraph::error(const HybridValues &values) const {
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = 0.0; double error = 0.0;
for (size_t idx = 0; idx < size(); idx++) { for (size_t idx = 0; idx < size(); idx++) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
auto factor = factors_.at(idx); auto factor = factors_.at(idx);
if (factor->isHybrid()) { if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) { if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(continuousValues, discreteValues); error += c->asMixture()->error(values);
} }
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) { if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(continuousValues, discreteValues); error += f->error(values);
} }
} else if (factor->isContinuous()) { } else if (factor->isContinuous()) {
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) { if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(continuousValues); error += f->inner()->error(values.continuous());
} }
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) { if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(continuousValues); error += cg->asGaussian()->error(values.continuous());
} }
} }
} }
@ -527,10 +528,8 @@ double HybridGaussianFactorGraph::error(
} }
/* ************************************************************************ */ /* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime( double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
const VectorValues &continuousValues, double error = this->error(values);
const DiscreteValues &discreteValues) const {
double error = this->error(continuousValues, discreteValues);
// NOTE: The 0.5 term is handled by each factor // NOTE: The 0.5 term is handled by each factor
return std::exp(-error); return std::exp(-error);
} }

View File

@ -186,14 +186,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Compute error given a continuous vector values * @brief Compute error given a continuous vector values
* and a discrete assignment. * and a discrete assignment.
* *
* @param continuousValues The continuous VectorValues
* for computing the error.
* @param discreteValues The specific discrete assignment
* whose error we wish to compute.
* @return double * @return double
*/ */
double error(const VectorValues& continuousValues, double error(const HybridValues& values) const;
const DiscreteValues& discreteValues) const;
/** /**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
@ -210,13 +205,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Compute the unnormalized posterior probability for a continuous * @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment. * vector values given a specific assignment.
* *
* @param continuousValues The vector values for which to compute the
* posterior probability.
* @param discreteValues The specific assignment to use for the computation.
* @return double * @return double
*/ */
double probPrime(const VectorValues& continuousValues, double probPrime(const HybridValues& values) const;
const DiscreteValues& discreteValues) const;
/** /**
* @brief Return a Colamd constrained ordering where the discrete keys are * @brief Return a Colamd constrained ordering where the discrete keys are