Merge pull request #1318 from borglab/hybrid/error
commit
07d0a031b3
|
@ -71,7 +71,7 @@ namespace gtsam {
|
|||
static inline double id(const double& x) { return x; }
|
||||
};
|
||||
|
||||
AlgebraicDecisionTree() : Base(1.0) {}
|
||||
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
|
||||
|
||||
// Explicitly non-explicit constructor
|
||||
AlgebraicDecisionTree(const Base& add) : Base(add) {}
|
||||
|
@ -158,9 +158,9 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/// print method customized to value type `double`.
|
||||
void print(const std::string& s,
|
||||
const typename Base::LabelFormatter& labelFormatter =
|
||||
&DefaultFormatter) const {
|
||||
void print(const std::string& s = "",
|
||||
const typename Base::LabelFormatter& labelFormatter =
|
||||
&DefaultFormatter) const {
|
||||
auto valueFormatter = [](const double& v) {
|
||||
return (boost::format("%4.8g") % v).str();
|
||||
};
|
||||
|
|
|
@ -85,8 +85,8 @@ size_t GaussianMixture::nrComponents() const {
|
|||
|
||||
/* *******************************************************************************/
|
||||
GaussianConditional::shared_ptr GaussianMixture::operator()(
|
||||
const DiscreteValues &discreteVals) const {
|
||||
auto &ptr = conditionals_(discreteVals);
|
||||
const DiscreteValues &discreteValues) const {
|
||||
auto &ptr = conditionals_(discreteValues);
|
||||
if (!ptr) return nullptr;
|
||||
auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr);
|
||||
if (conditional)
|
||||
|
@ -207,4 +207,30 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
|||
conditionals_.root_ = pruned_conditionals.root_;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to calculate to double error value from GaussianConditional.
|
||||
auto errorFunc =
|
||||
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
|
||||
if (conditional) {
|
||||
return conditional->error(continuousValues);
|
||||
} else {
|
||||
// Return arbitrarily large error if conditional is null
|
||||
// Conditional is null if it is pruned out.
|
||||
return 1e50;
|
||||
}
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
||||
return errorTree;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixture::error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto conditional = conditionals_(discreteValues);
|
||||
return conditional->error(continuousValues);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -122,7 +122,7 @@ class GTSAM_EXPORT GaussianMixture
|
|||
/// @{
|
||||
|
||||
GaussianConditional::shared_ptr operator()(
|
||||
const DiscreteValues &discreteVals) const;
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
/// Returns the total number of continuous components
|
||||
size_t nrComponents() const;
|
||||
|
@ -144,6 +144,26 @@ class GTSAM_EXPORT GaussianMixture
|
|||
/// Getter for the underlying Conditionals DecisionTree
|
||||
const Conditionals &conditionals();
|
||||
|
||||
/**
|
||||
* @brief Compute error of the GaussianMixture as a tree.
|
||||
*
|
||||
* @param continuousValues The continuous VectorValues.
|
||||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||
* as the conditionals, and leaf values as the error.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||
* values and a discrete assignment.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @param discreteValues The discrete assignment for a specific mode sequence.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||
* `decisionTree`.
|
||||
|
|
|
@ -95,4 +95,26 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
|||
};
|
||||
return {factors_, wrap};
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc =
|
||||
[continuousValues](const GaussianFactor::shared_ptr &factor) {
|
||||
return factor->error(continuousValues);
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||
return errorTree;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixtureFactor::error(
|
||||
const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto factor = factors_(discreteValues);
|
||||
return factor->error(continuousValues);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -20,15 +20,19 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class GaussianFactorGraph;
|
||||
|
||||
// Needed for wrapper.
|
||||
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
|
||||
|
||||
/**
|
||||
|
@ -126,6 +130,26 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
*/
|
||||
Sum add(const Sum &sum) const;
|
||||
|
||||
/**
|
||||
* @brief Compute error of the GaussianMixtureFactor as a tree.
|
||||
*
|
||||
* @param continuousValues The continuous VectorValues.
|
||||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||
* as the factors involved, and leaf values as the error.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||
* values and a discrete assignment.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @param discreteValues The discrete assignment for a specific mode sequence.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
/// Add MixtureFactor to a Sum, syntactic sugar.
|
||||
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
||||
sum = factor.add(sum);
|
||||
|
|
|
@ -232,4 +232,56 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
|||
return gbn.optimize();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double HybridBayesNet::error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
GaussianBayesNet gbn = this->choose(discreteValues);
|
||||
return gbn.error(continuousValues);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree;
|
||||
|
||||
// Iterate over each factor.
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
AlgebraicDecisionTree<Key> conditional_error;
|
||||
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment and compute error.
|
||||
GaussianMixture::shared_ptr gm = this->atMixture(idx);
|
||||
conditional_error = gm->error(continuousValues);
|
||||
|
||||
// Assign for the first index, add error for subsequent ones.
|
||||
if (idx == 0) {
|
||||
error_tree = conditional_error;
|
||||
} else {
|
||||
error_tree = error_tree + conditional_error;
|
||||
}
|
||||
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
double error = this->atGaussian(idx)->error(continuousValues);
|
||||
// Add the computed error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return error_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
||||
return error_tree.apply([](double error) { return exp(-error); });
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -124,6 +124,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
HybridBayesNet prune(size_t maxNrLeaves);
|
||||
|
||||
/**
|
||||
* @brief 0.5 * sum of squared Mahalanobis distances
|
||||
* for a specific discrete assignment.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @param discreteValues Discrete assignment for a specific mode sequence.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute conditional error for each discrete assignment,
|
||||
* and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability q(μ|M),
|
||||
* for each discrete assignment, and return as a tree.
|
||||
* q(μ|M) is the unnormalized probability at the MLE point μ,
|
||||
* conditioned on the discrete variables.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the
|
||||
* probability.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> probPrime(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
@ -423,4 +423,58 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
|||
return ordering;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
// Iterate over each factor.
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
AlgebraicDecisionTree<Key> factor_error;
|
||||
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixtureFactor::shared_ptr gaussianMixture =
|
||||
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
|
||||
// Compute factor error.
|
||||
factor_error = gaussianMixture->error(continuousValues);
|
||||
|
||||
// If first factor, assign error, else add it.
|
||||
if (idx == 0) {
|
||||
error_tree = factor_error;
|
||||
} else {
|
||||
error_tree = error_tree + factor_error;
|
||||
}
|
||||
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
auto hybridGaussianFactor =
|
||||
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
|
||||
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
|
||||
|
||||
// Compute the error of the gaussian factor.
|
||||
double error = gaussian->error(continuousValues);
|
||||
// Add the gaussian factor error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return error_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
||||
AlgebraicDecisionTree<Key> prob_tree =
|
||||
error_tree.apply([](double error) { return exp(-error); });
|
||||
return prob_tree;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -41,7 +41,7 @@ class JacobianFactor;
|
|||
|
||||
/**
|
||||
* @brief Main elimination function for HybridGaussianFactorGraph.
|
||||
*
|
||||
*
|
||||
* @param factors The factor graph to eliminate.
|
||||
* @param keys The elimination ordering.
|
||||
* @return The conditional on the ordering keys and the remaining factors.
|
||||
|
@ -99,11 +99,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
|
||||
|
||||
using Values = gtsam::Values; ///< backwards compatibility
|
||||
using Indices = KeyVector; ///> map from keys to values
|
||||
using Indices = KeyVector; ///< map from keys to values
|
||||
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// @brief Default constructor.
|
||||
HybridGaussianFactorGraph() = default;
|
||||
|
||||
/**
|
||||
|
@ -170,6 +171,28 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Compute error for each discrete assignment,
|
||||
* and return as a tree.
|
||||
*
|
||||
* Error \f$ e = \Vert x - \mu \Vert_{\Sigma} \f$.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||
* for each discrete assignment, and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the
|
||||
* probability.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> probPrime(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
* eliminated after the continuous keys.
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearFactor.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
#include <gtsam/nonlinear/Symbol.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -86,11 +87,11 @@ class MixtureFactor : public HybridFactor {
|
|||
* elements based on the number of discrete keys and the cardinality of the
|
||||
* keys, so that the decision tree is constructed appropriately.
|
||||
*
|
||||
* @tparam FACTOR The type of the factor shared pointers being passed in. Will
|
||||
* be typecast to NonlinearFactor shared pointers.
|
||||
* @tparam FACTOR The type of the factor shared pointers being passed in.
|
||||
* Will be typecast to NonlinearFactor shared pointers.
|
||||
* @param keys Vector of keys for continuous factors.
|
||||
* @param discreteKeys Vector of discrete keys.
|
||||
* @param factors Vector of shared pointers to factors.
|
||||
* @param factors Vector of nonlinear factors.
|
||||
* @param normalized Flag indicating if the factor error is already
|
||||
* normalized.
|
||||
*/
|
||||
|
@ -107,8 +108,12 @@ class MixtureFactor : public HybridFactor {
|
|||
std::copy(f->keys().begin(), f->keys().end(),
|
||||
std::inserter(factor_keys_set, factor_keys_set.end()));
|
||||
|
||||
nonlinear_factors.push_back(
|
||||
boost::dynamic_pointer_cast<NonlinearFactor>(f));
|
||||
if (auto nf = boost::dynamic_pointer_cast<NonlinearFactor>(f)) {
|
||||
nonlinear_factors.push_back(nf);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Factors passed into MixtureFactor need to be nonlinear!");
|
||||
}
|
||||
}
|
||||
factors_ = Factors(discreteKeys, nonlinear_factors);
|
||||
|
||||
|
@ -121,22 +126,39 @@ class MixtureFactor : public HybridFactor {
|
|||
|
||||
~MixtureFactor() = default;
|
||||
|
||||
/**
|
||||
* @brief Compute error of the MixtureFactor as a tree.
|
||||
*
|
||||
* @param continuousValues The continuous values for which to compute the
|
||||
* error.
|
||||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||
* as the factor, and leaf values as the error.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const Values& continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc = [continuousValues](const sharedFactor& factor) {
|
||||
return factor->error(continuousValues);
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||
return errorTree;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Compute error of factor given both continuous and discrete values.
|
||||
*
|
||||
* @param continuousVals The continuous Values.
|
||||
* @param discreteVals The discrete Values.
|
||||
* @param continuousValues The continuous Values.
|
||||
* @param discreteValues The discrete Values.
|
||||
* @return double The error of this factor.
|
||||
*/
|
||||
double error(const Values& continuousVals,
|
||||
const DiscreteValues& discreteVals) const {
|
||||
// Retrieve the factor corresponding to the assignment in discreteVals.
|
||||
auto factor = factors_(discreteVals);
|
||||
double error(const Values& continuousValues,
|
||||
const DiscreteValues& discreteValues) const {
|
||||
// Retrieve the factor corresponding to the assignment in discreteValues.
|
||||
auto factor = factors_(discreteValues);
|
||||
// Compute the error for the selected factor
|
||||
const double factorError = factor->error(continuousVals);
|
||||
const double factorError = factor->error(continuousValues);
|
||||
if (normalized_) return factorError;
|
||||
return factorError +
|
||||
this->nonlinearFactorLogNormalizingConstant(factor, continuousVals);
|
||||
return factorError + this->nonlinearFactorLogNormalizingConstant(
|
||||
factor, continuousValues);
|
||||
}
|
||||
|
||||
size_t dim() const {
|
||||
|
@ -149,7 +171,7 @@ class MixtureFactor : public HybridFactor {
|
|||
|
||||
/// print to stdout
|
||||
void print(
|
||||
const std::string& s = "MixtureFactor",
|
||||
const std::string& s = "",
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
|
||||
std::cout << (s.empty() ? "" : s + " ");
|
||||
Base::print("", keyFormatter);
|
||||
|
@ -192,17 +214,18 @@ class MixtureFactor : public HybridFactor {
|
|||
/// Linearize specific nonlinear factors based on the assignment in
|
||||
/// discreteValues.
|
||||
GaussianFactor::shared_ptr linearize(
|
||||
const Values& continuousVals, const DiscreteValues& discreteVals) const {
|
||||
auto factor = factors_(discreteVals);
|
||||
return factor->linearize(continuousVals);
|
||||
const Values& continuousValues,
|
||||
const DiscreteValues& discreteValues) const {
|
||||
auto factor = factors_(discreteValues);
|
||||
return factor->linearize(continuousValues);
|
||||
}
|
||||
|
||||
/// Linearize all the continuous factors to get a GaussianMixtureFactor.
|
||||
boost::shared_ptr<GaussianMixtureFactor> linearize(
|
||||
const Values& continuousVals) const {
|
||||
const Values& continuousValues) const {
|
||||
// functional to linearize each factor in the decision tree
|
||||
auto linearizeDT = [continuousVals](const sharedFactor& factor) {
|
||||
return factor->linearize(continuousVals);
|
||||
auto linearizeDT = [continuousValues](const sharedFactor& factor) {
|
||||
return factor->linearize(continuousValues);
|
||||
};
|
||||
|
||||
DecisionTree<Key, GaussianFactor::shared_ptr> linearized_factors(
|
||||
|
|
|
@ -196,22 +196,24 @@ class HybridNonlinearFactorGraph {
|
|||
|
||||
#include <gtsam/hybrid/MixtureFactor.h>
|
||||
class MixtureFactor : gtsam::HybridFactor {
|
||||
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
|
||||
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors, bool normalized = false);
|
||||
MixtureFactor(
|
||||
const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
|
||||
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors,
|
||||
bool normalized = false);
|
||||
|
||||
template <FACTOR = {gtsam::NonlinearFactor}>
|
||||
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
|
||||
const std::vector<FACTOR*>& factors,
|
||||
bool normalized = false);
|
||||
|
||||
double error(const gtsam::Values& continuousVals,
|
||||
const gtsam::DiscreteValues& discreteVals) const;
|
||||
double error(const gtsam::Values& continuousValues,
|
||||
const gtsam::DiscreteValues& discreteValues) const;
|
||||
|
||||
double nonlinearFactorLogNormalizingConstant(const gtsam::NonlinearFactor* factor,
|
||||
const gtsam::Values& values) const;
|
||||
|
||||
GaussianMixtureFactor* linearize(
|
||||
const gtsam::Values& continuousVals) const;
|
||||
const gtsam::Values& continuousValues) const;
|
||||
|
||||
void print(string s = "MixtureFactor\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
|
|
|
@ -78,15 +78,58 @@ TEST(GaussianMixture, Equals) {
|
|||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals);
|
||||
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
|
||||
|
||||
// Let's check that this worked:
|
||||
DiscreteValues mode;
|
||||
mode[m1.first] = 1;
|
||||
auto actual = mixtureFactor(mode);
|
||||
auto actual = mixture(mode);
|
||||
EXPECT(actual == conditional1);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/// Test error method of GaussianMixture.
|
||||
TEST(GaussianMixture, Error) {
|
||||
Matrix22 S1 = Matrix22::Identity();
|
||||
Matrix22 S2 = Matrix22::Identity() * 2;
|
||||
Matrix22 R1 = Matrix22::Ones();
|
||||
Matrix22 R2 = Matrix22::Ones();
|
||||
Vector2 d1(1, 2), d2(2, 1);
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
|
||||
X(2), S1, model),
|
||||
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
|
||||
X(2), S2, model);
|
||||
|
||||
// Create decision tree
|
||||
DiscreteKey m1(M(1), 2);
|
||||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
|
||||
|
||||
VectorValues values;
|
||||
values.insert(X(1), Vector2::Ones());
|
||||
values.insert(X(2), Vector2::Zero());
|
||||
auto error_tree = mixture.error(values);
|
||||
|
||||
// regression
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
std::vector<double> leaves = {0.5, 4.3252595};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
||||
|
||||
// Regression for non-tree version.
|
||||
DiscreteValues assignment;
|
||||
assignment[M(1)] = 0;
|
||||
EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8);
|
||||
assignment[M(1)] = 1;
|
||||
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), 1e-8);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file GaussianMixtureFactor.cpp
|
||||
* @file testGaussianMixtureFactor.cpp
|
||||
* @brief Unit tests for GaussianMixtureFactor
|
||||
* @author Varun Agrawal
|
||||
* @author Fan Jiang
|
||||
|
@ -135,7 +135,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
|||
EXPECT(assert_print_equal(expected, mixtureFactor));
|
||||
}
|
||||
|
||||
TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) {
|
||||
TEST(GaussianMixtureFactor, GaussianMixture) {
|
||||
KeyVector keys;
|
||||
keys.push_back(X(0));
|
||||
keys.push_back(X(1));
|
||||
|
@ -151,6 +151,46 @@ TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) {
|
|||
EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test the error of the GaussianMixtureFactor
|
||||
TEST(GaussianMixtureFactor, Error) {
|
||||
DiscreteKey m1(1, 2);
|
||||
|
||||
auto A01 = Matrix2::Identity();
|
||||
auto A02 = Matrix2::Identity();
|
||||
|
||||
auto A11 = Matrix2::Identity();
|
||||
auto A12 = Matrix2::Identity() * 2;
|
||||
|
||||
auto b = Vector2::Zero();
|
||||
|
||||
auto f0 = boost::make_shared<JacobianFactor>(X(1), A01, X(2), A02, b);
|
||||
auto f1 = boost::make_shared<JacobianFactor>(X(1), A11, X(2), A12, b);
|
||||
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
|
||||
|
||||
GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
|
||||
|
||||
VectorValues continuousValues;
|
||||
continuousValues.insert(X(1), Vector2(0, 0));
|
||||
continuousValues.insert(X(2), Vector2(1, 1));
|
||||
|
||||
// error should return a tree of errors, with nodes for each discrete value.
|
||||
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
// Error values for regression test
|
||||
std::vector<double> errors = {1, 4};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
|
||||
|
||||
EXPECT(assert_equal(expected_error, error_tree));
|
||||
|
||||
// Test for single leaf given discrete assignment P(X|M,Z).
|
||||
DiscreteValues discreteValues;
|
||||
discreteValues[m1.first] = 1;
|
||||
EXPECT_DOUBLES_EQUAL(
|
||||
4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -183,6 +183,60 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
|
|||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net error
|
||||
TEST(HybridBayesNet, Error) {
|
||||
Switching s(3);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = hybridBayesNet->error(delta.continuous());
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
std::vector<double> leaves = {0.0097568009, 3.3973404e-31, 0.029126214,
|
||||
0.0097568009};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-9));
|
||||
|
||||
// Error on pruned bayes net
|
||||
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||
auto pruned_error_tree = prunedBayesNet.error(delta.continuous());
|
||||
|
||||
std::vector<double> pruned_leaves = {2e50, 3.3973404e-31, 2e50, 0.0097568009};
|
||||
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
|
||||
pruned_leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9));
|
||||
|
||||
// Verify error computation and check for specific error value
|
||||
DiscreteValues discrete_values;
|
||||
boost::assign::insert(discrete_values)(M(0), 1)(M(1), 1);
|
||||
|
||||
double total_error = 0;
|
||||
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
|
||||
if (hybridBayesNet->at(idx)->isHybrid()) {
|
||||
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(),
|
||||
discrete_values);
|
||||
total_error += error;
|
||||
} else if (hybridBayesNet->at(idx)->isContinuous()) {
|
||||
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
|
||||
total_error += error;
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(
|
||||
total_error, hybridBayesNet->error(delta.continuous(), discrete_values),
|
||||
1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net pruning
|
||||
TEST(HybridBayesNet, Prune) {
|
||||
|
|
|
@ -562,6 +562,36 @@ TEST(HybridGaussianFactorGraph, Conditionals) {
|
|||
EXPECT(assert_equal(expected_discrete, result.discrete()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test hybrid gaussian factor graph error and unnormalized probabilities
|
||||
TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
||||
Switching s(3);
|
||||
|
||||
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
||||
|
||||
Ordering hybridOrdering = graph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
graph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = graph.error(delta.continuous());
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
||||
|
||||
auto probs = graph.probPrime(delta.continuous());
|
||||
std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
|
||||
0.99029064};
|
||||
AlgebraicDecisionTree<Key> expected_probs(discrete_keys, prob_leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_probs, probs, 1e-7));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testMixtureFactor.cpp
|
||||
* @brief Unit tests for MixtureFactor
|
||||
* @author Varun Agrawal
|
||||
* @date October 2022
|
||||
*/
|
||||
|
||||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/MixtureFactor.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/slam/BetweenFactor.h>
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check iterators of empty mixture.
|
||||
TEST(MixtureFactor, Constructor) {
|
||||
MixtureFactor factor;
|
||||
MixtureFactor::const_iterator const_it = factor.begin();
|
||||
CHECK(const_it == factor.end());
|
||||
MixtureFactor::iterator it = factor.begin();
|
||||
CHECK(it == factor.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test .print() output.
|
||||
TEST(MixtureFactor, Printing) {
|
||||
DiscreteKey m1(1, 2);
|
||||
double between0 = 0.0;
|
||||
double between1 = 1.0;
|
||||
|
||||
Vector1 sigmas = Vector1(1.0);
|
||||
auto model = noiseModel::Diagonal::Sigmas(sigmas, false);
|
||||
|
||||
auto f0 =
|
||||
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between0, model);
|
||||
auto f1 =
|
||||
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model);
|
||||
std::vector<NonlinearFactor::shared_ptr> factors{f0, f1};
|
||||
|
||||
MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
|
||||
|
||||
std::string expected =
|
||||
R"(Hybrid [x1 x2; 1]
|
||||
MixtureFactor
|
||||
Choice(1)
|
||||
0 Leaf Nonlinear factor on 2 keys
|
||||
1 Leaf Nonlinear factor on 2 keys
|
||||
)";
|
||||
EXPECT(assert_print_equal(expected, mixtureFactor));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test the error of the MixtureFactor
|
||||
TEST(MixtureFactor, Error) {
|
||||
DiscreteKey m1(1, 2);
|
||||
|
||||
double between0 = 0.0;
|
||||
double between1 = 1.0;
|
||||
|
||||
Vector1 sigmas = Vector1(1.0);
|
||||
auto model = noiseModel::Diagonal::Sigmas(sigmas, false);
|
||||
|
||||
auto f0 =
|
||||
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between0, model);
|
||||
auto f1 =
|
||||
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model);
|
||||
std::vector<NonlinearFactor::shared_ptr> factors{f0, f1};
|
||||
|
||||
MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
|
||||
|
||||
Values continuousValues;
|
||||
continuousValues.insert<double>(X(1), 0);
|
||||
continuousValues.insert<double>(X(2), 1);
|
||||
|
||||
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
std::vector<double> errors = {0.5, 0};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
|
||||
|
||||
EXPECT(assert_equal(expected_error, error_tree));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
Loading…
Reference in New Issue