Merge pull request #1318 from borglab/hybrid/error

release/4.3a0
Varun Agrawal 2022-12-23 00:22:38 -05:00 committed by GitHub
commit 07d0a031b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 594 additions and 39 deletions

View File

@ -71,7 +71,7 @@ namespace gtsam {
static inline double id(const double& x) { return x; } static inline double id(const double& x) { return x; }
}; };
AlgebraicDecisionTree() : Base(1.0) {} AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
// Explicitly non-explicit constructor // Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {} AlgebraicDecisionTree(const Base& add) : Base(add) {}
@ -158,7 +158,7 @@ namespace gtsam {
} }
/// print method customized to value type `double`. /// print method customized to value type `double`.
void print(const std::string& s, void print(const std::string& s = "",
const typename Base::LabelFormatter& labelFormatter = const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const { &DefaultFormatter) const {
auto valueFormatter = [](const double& v) { auto valueFormatter = [](const double& v) {

View File

@ -85,8 +85,8 @@ size_t GaussianMixture::nrComponents() const {
/* *******************************************************************************/ /* *******************************************************************************/
GaussianConditional::shared_ptr GaussianMixture::operator()( GaussianConditional::shared_ptr GaussianMixture::operator()(
const DiscreteValues &discreteVals) const { const DiscreteValues &discreteValues) const {
auto &ptr = conditionals_(discreteVals); auto &ptr = conditionals_(discreteValues);
if (!ptr) return nullptr; if (!ptr) return nullptr;
auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr); auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional) if (conditional)
@ -207,4 +207,30 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
conditionals_.root_ = pruned_conditionals.root_; conditionals_.root_ = pruned_conditionals.root_;
} }
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousValues) const {
// functor to calculate to double error value from GaussianConditional.
auto errorFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->error(continuousValues);
} else {
// Return arbitrarily large error if conditional is null
// Conditional is null if it is pruned out.
return 1e50;
}
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}
/* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(discreteValues);
return conditional->error(continuousValues);
}
} // namespace gtsam } // namespace gtsam

View File

@ -122,7 +122,7 @@ class GTSAM_EXPORT GaussianMixture
/// @{ /// @{
GaussianConditional::shared_ptr operator()( GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteVals) const; const DiscreteValues &discreteValues) const;
/// Returns the total number of continuous components /// Returns the total number of continuous components
size_t nrComponents() const; size_t nrComponents() const;
@ -144,6 +144,26 @@ class GTSAM_EXPORT GaussianMixture
/// Getter for the underlying Conditionals DecisionTree /// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals(); const Conditionals &conditionals();
/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`. * `decisionTree`.

View File

@ -95,4 +95,26 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
}; };
return {factors_, wrap}; return {factors_, wrap};
} }
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const GaussianFactor::shared_ptr &factor) {
return factor->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}
/* *******************************************************************************/
double GaussianMixtureFactor::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto factor = factors_(discreteValues);
return factor->error(continuousValues);
}
} // namespace gtsam } // namespace gtsam

View File

@ -20,15 +20,19 @@
#pragma once #pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>
namespace gtsam { namespace gtsam {
class GaussianFactorGraph; class GaussianFactorGraph;
// Needed for wrapper.
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>; using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
/** /**
@ -126,6 +130,26 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/ */
Sum add(const Sum &sum) const; Sum add(const Sum &sum) const;
/**
* @brief Compute error of the GaussianMixtureFactor as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
/// Add MixtureFactor to a Sum, syntactic sugar. /// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum); sum = factor.add(sum);

View File

@ -232,4 +232,56 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
return gbn.optimize(); return gbn.optimize();
} }
/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = this->choose(discreteValues);
return gbn.error(continuousValues);
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree;
// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> conditional_error;
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = this->atMixture(idx);
conditional_error = gm->error(continuousValues);
// Assign for the first index, add error for subsequent ones.
if (idx == 0) {
error_tree = conditional_error;
} else {
error_tree = error_tree + conditional_error;
}
} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = this->atGaussian(idx)->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}
return error_tree;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
return error_tree.apply([](double error) { return exp(-error); });
}
} // namespace gtsam } // namespace gtsam

View File

@ -124,6 +124,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves); HybridBayesNet prune(size_t maxNrLeaves);
/**
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
/**
* @brief Compute conditional error for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute unnormalized probability q(μ|M),
* for each discrete assignment, and return as a tree.
* q(μ|M) is the unnormalized probability at the MLE point μ,
* conditioned on the discrete variables.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues &continuousValues) const;
/// @} /// @}
private: private:

View File

@ -423,4 +423,58 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
return ordering; return ordering;
} }
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> factor_error;
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
// Compute factor error.
factor_error = gaussianMixture->error(continuousValues);
// If first factor, assign error, else add it.
if (idx == 0) {
error_tree = factor_error;
} else {
error_tree = error_tree + factor_error;
}
} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
auto hybridGaussianFactor =
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
// Compute the error of the gaussian factor.
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}
return error_tree;
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> prob_tree =
error_tree.apply([](double error) { return exp(-error); });
return prob_tree;
}
} // namespace gtsam } // namespace gtsam

View File

@ -99,11 +99,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
using Values = gtsam::Values; ///< backwards compatibility using Values = gtsam::Values; ///< backwards compatibility
using Indices = KeyVector; ///> map from keys to values using Indices = KeyVector; ///< map from keys to values
/// @name Constructors /// @name Constructors
/// @{ /// @{
/// @brief Default constructor.
HybridGaussianFactorGraph() = default; HybridGaussianFactorGraph() = default;
/** /**
@ -170,6 +171,28 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
} }
} }
/**
* @brief Compute error for each discrete assignment,
* and return as a tree.
*
* Error \f$ e = \Vert x - \mu \Vert_{\Sigma} \f$.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;
/** /**
* @brief Return a Colamd constrained ordering where the discrete keys are * @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys. * eliminated after the continuous keys.

View File

@ -23,6 +23,7 @@
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridNonlinearFactor.h> #include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/nonlinear/NonlinearFactor.h> #include <gtsam/nonlinear/NonlinearFactor.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/nonlinear/Symbol.h> #include <gtsam/nonlinear/Symbol.h>
#include <algorithm> #include <algorithm>
@ -86,11 +87,11 @@ class MixtureFactor : public HybridFactor {
* elements based on the number of discrete keys and the cardinality of the * elements based on the number of discrete keys and the cardinality of the
* keys, so that the decision tree is constructed appropriately. * keys, so that the decision tree is constructed appropriately.
* *
* @tparam FACTOR The type of the factor shared pointers being passed in. Will * @tparam FACTOR The type of the factor shared pointers being passed in.
* be typecast to NonlinearFactor shared pointers. * Will be typecast to NonlinearFactor shared pointers.
* @param keys Vector of keys for continuous factors. * @param keys Vector of keys for continuous factors.
* @param discreteKeys Vector of discrete keys. * @param discreteKeys Vector of discrete keys.
* @param factors Vector of shared pointers to factors. * @param factors Vector of nonlinear factors.
* @param normalized Flag indicating if the factor error is already * @param normalized Flag indicating if the factor error is already
* normalized. * normalized.
*/ */
@ -107,8 +108,12 @@ class MixtureFactor : public HybridFactor {
std::copy(f->keys().begin(), f->keys().end(), std::copy(f->keys().begin(), f->keys().end(),
std::inserter(factor_keys_set, factor_keys_set.end())); std::inserter(factor_keys_set, factor_keys_set.end()));
nonlinear_factors.push_back( if (auto nf = boost::dynamic_pointer_cast<NonlinearFactor>(f)) {
boost::dynamic_pointer_cast<NonlinearFactor>(f)); nonlinear_factors.push_back(nf);
} else {
throw std::runtime_error(
"Factors passed into MixtureFactor need to be nonlinear!");
}
} }
factors_ = Factors(discreteKeys, nonlinear_factors); factors_ = Factors(discreteKeys, nonlinear_factors);
@ -121,22 +126,39 @@ class MixtureFactor : public HybridFactor {
~MixtureFactor() = default; ~MixtureFactor() = default;
/**
* @brief Compute error of the MixtureFactor as a tree.
*
* @param continuousValues The continuous values for which to compute the
* error.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factor, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const Values& continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousValues](const sharedFactor& factor) {
return factor->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}
/** /**
* @brief Compute error of factor given both continuous and discrete values. * @brief Compute error of factor given both continuous and discrete values.
* *
* @param continuousVals The continuous Values. * @param continuousValues The continuous Values.
* @param discreteVals The discrete Values. * @param discreteValues The discrete Values.
* @return double The error of this factor. * @return double The error of this factor.
*/ */
double error(const Values& continuousVals, double error(const Values& continuousValues,
const DiscreteValues& discreteVals) const { const DiscreteValues& discreteValues) const {
// Retrieve the factor corresponding to the assignment in discreteVals. // Retrieve the factor corresponding to the assignment in discreteValues.
auto factor = factors_(discreteVals); auto factor = factors_(discreteValues);
// Compute the error for the selected factor // Compute the error for the selected factor
const double factorError = factor->error(continuousVals); const double factorError = factor->error(continuousValues);
if (normalized_) return factorError; if (normalized_) return factorError;
return factorError + return factorError + this->nonlinearFactorLogNormalizingConstant(
this->nonlinearFactorLogNormalizingConstant(factor, continuousVals); factor, continuousValues);
} }
size_t dim() const { size_t dim() const {
@ -149,7 +171,7 @@ class MixtureFactor : public HybridFactor {
/// print to stdout /// print to stdout
void print( void print(
const std::string& s = "MixtureFactor", const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
std::cout << (s.empty() ? "" : s + " "); std::cout << (s.empty() ? "" : s + " ");
Base::print("", keyFormatter); Base::print("", keyFormatter);
@ -192,17 +214,18 @@ class MixtureFactor : public HybridFactor {
/// Linearize specific nonlinear factors based on the assignment in /// Linearize specific nonlinear factors based on the assignment in
/// discreteValues. /// discreteValues.
GaussianFactor::shared_ptr linearize( GaussianFactor::shared_ptr linearize(
const Values& continuousVals, const DiscreteValues& discreteVals) const { const Values& continuousValues,
auto factor = factors_(discreteVals); const DiscreteValues& discreteValues) const {
return factor->linearize(continuousVals); auto factor = factors_(discreteValues);
return factor->linearize(continuousValues);
} }
/// Linearize all the continuous factors to get a GaussianMixtureFactor. /// Linearize all the continuous factors to get a GaussianMixtureFactor.
boost::shared_ptr<GaussianMixtureFactor> linearize( boost::shared_ptr<GaussianMixtureFactor> linearize(
const Values& continuousVals) const { const Values& continuousValues) const {
// functional to linearize each factor in the decision tree // functional to linearize each factor in the decision tree
auto linearizeDT = [continuousVals](const sharedFactor& factor) { auto linearizeDT = [continuousValues](const sharedFactor& factor) {
return factor->linearize(continuousVals); return factor->linearize(continuousValues);
}; };
DecisionTree<Key, GaussianFactor::shared_ptr> linearized_factors( DecisionTree<Key, GaussianFactor::shared_ptr> linearized_factors(

View File

@ -196,22 +196,24 @@ class HybridNonlinearFactorGraph {
#include <gtsam/hybrid/MixtureFactor.h> #include <gtsam/hybrid/MixtureFactor.h>
class MixtureFactor : gtsam::HybridFactor { class MixtureFactor : gtsam::HybridFactor {
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, MixtureFactor(
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors, bool normalized = false); const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors,
bool normalized = false);
template <FACTOR = {gtsam::NonlinearFactor}> template <FACTOR = {gtsam::NonlinearFactor}>
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const std::vector<FACTOR*>& factors, const std::vector<FACTOR*>& factors,
bool normalized = false); bool normalized = false);
double error(const gtsam::Values& continuousVals, double error(const gtsam::Values& continuousValues,
const gtsam::DiscreteValues& discreteVals) const; const gtsam::DiscreteValues& discreteValues) const;
double nonlinearFactorLogNormalizingConstant(const gtsam::NonlinearFactor* factor, double nonlinearFactorLogNormalizingConstant(const gtsam::NonlinearFactor* factor,
const gtsam::Values& values) const; const gtsam::Values& values) const;
GaussianMixtureFactor* linearize( GaussianMixtureFactor* linearize(
const gtsam::Values& continuousVals) const; const gtsam::Values& continuousValues) const;
void print(string s = "MixtureFactor\n", void print(string s = "MixtureFactor\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =

View File

@ -78,15 +78,58 @@ TEST(GaussianMixture, Equals) {
GaussianMixture::Conditionals conditionals( GaussianMixture::Conditionals conditionals(
{m1}, {m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1}); vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals); GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
// Let's check that this worked: // Let's check that this worked:
DiscreteValues mode; DiscreteValues mode;
mode[m1.first] = 1; mode[m1.first] = 1;
auto actual = mixtureFactor(mode); auto actual = mixture(mode);
EXPECT(actual == conditional1); EXPECT(actual == conditional1);
} }
/* ************************************************************************* */
/// Test error method of GaussianMixture.
TEST(GaussianMixture, Error) {
Matrix22 S1 = Matrix22::Identity();
Matrix22 S2 = Matrix22::Identity() * 2;
Matrix22 R1 = Matrix22::Ones();
Matrix22 R2 = Matrix22::Ones();
Vector2 d1(1, 2), d2(2, 1);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
X(2), S1, model),
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
X(2), S2, model);
// Create decision tree
DiscreteKey m1(M(1), 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
VectorValues values;
values.insert(X(1), Vector2::Ones());
values.insert(X(2), Vector2::Zero());
auto error_tree = mixture.error(values);
// regression
std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> leaves = {0.5, 4.3252595};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
// Regression for non-tree version.
DiscreteValues assignment;
assignment[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8);
assignment[M(1)] = 1;
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), 1e-8);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file GaussianMixtureFactor.cpp * @file testGaussianMixtureFactor.cpp
* @brief Unit tests for GaussianMixtureFactor * @brief Unit tests for GaussianMixtureFactor
* @author Varun Agrawal * @author Varun Agrawal
* @author Fan Jiang * @author Fan Jiang
@ -135,7 +135,7 @@ TEST(GaussianMixtureFactor, Printing) {
EXPECT(assert_print_equal(expected, mixtureFactor)); EXPECT(assert_print_equal(expected, mixtureFactor));
} }
TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) { TEST(GaussianMixtureFactor, GaussianMixture) {
KeyVector keys; KeyVector keys;
keys.push_back(X(0)); keys.push_back(X(0));
keys.push_back(X(1)); keys.push_back(X(1));
@ -151,6 +151,46 @@ TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) {
EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size()); EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size());
} }
/* ************************************************************************* */
// Test the error of the GaussianMixtureFactor
TEST(GaussianMixtureFactor, Error) {
DiscreteKey m1(1, 2);
auto A01 = Matrix2::Identity();
auto A02 = Matrix2::Identity();
auto A11 = Matrix2::Identity();
auto A12 = Matrix2::Identity() * 2;
auto b = Vector2::Zero();
auto f0 = boost::make_shared<JacobianFactor>(X(1), A01, X(2), A02, b);
auto f1 = boost::make_shared<JacobianFactor>(X(1), A11, X(2), A12, b);
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
VectorValues continuousValues;
continuousValues.insert(X(1), Vector2(0, 0));
continuousValues.insert(X(2), Vector2(1, 1));
// error should return a tree of errors, with nodes for each discrete value.
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
std::vector<DiscreteKey> discrete_keys = {m1};
// Error values for regression test
std::vector<double> errors = {1, 4};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
EXPECT(assert_equal(expected_error, error_tree));
// Test for single leaf given discrete assignment P(X|M,Z).
DiscreteValues discreteValues;
discreteValues[m1.first] = 1;
EXPECT_DOUBLES_EQUAL(
4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -183,6 +183,60 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
} }
/* ****************************************************************************/
// Test bayes net error
TEST(HybridBayesNet, Error) {
Switching s(3);
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->error(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.0097568009, 3.3973404e-31, 0.029126214,
0.0097568009};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-9));
// Error on pruned bayes net
auto prunedBayesNet = hybridBayesNet->prune(2);
auto pruned_error_tree = prunedBayesNet.error(delta.continuous());
std::vector<double> pruned_leaves = {2e50, 3.3973404e-31, 2e50, 0.0097568009};
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
pruned_leaves);
// regression
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9));
// Verify error computation and check for specific error value
DiscreteValues discrete_values;
boost::assign::insert(discrete_values)(M(0), 1)(M(1), 1);
double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
if (hybridBayesNet->at(idx)->isHybrid()) {
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(),
discrete_values);
total_error += error;
} else if (hybridBayesNet->at(idx)->isContinuous()) {
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
total_error += error;
}
}
EXPECT_DOUBLES_EQUAL(
total_error, hybridBayesNet->error(delta.continuous(), discrete_values),
1e-9);
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test bayes net pruning // Test bayes net pruning
TEST(HybridBayesNet, Prune) { TEST(HybridBayesNet, Prune) {

View File

@ -562,6 +562,36 @@ TEST(HybridGaussianFactorGraph, Conditionals) {
EXPECT(assert_equal(expected_discrete, result.discrete())); EXPECT(assert_equal(expected_discrete, result.discrete()));
} }
/* ****************************************************************************/
// Test hybrid gaussian factor graph error and unnormalized probabilities
TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
Switching s(3);
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
Ordering hybridOrdering = graph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.error(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
auto probs = graph.probPrime(delta.continuous());
std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
0.99029064};
AlgebraicDecisionTree<Key> expected_probs(discrete_keys, prob_leaves);
// regression
EXPECT(assert_equal(expected_probs, probs, 1e-7));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -0,0 +1,109 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testMixtureFactor.cpp
* @brief Unit tests for MixtureFactor
* @author Varun Agrawal
* @date October 2022
*/
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/MixtureFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/slam/BetweenFactor.h>
// Include for test suite
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
/* ************************************************************************* */
// Check iterators of empty mixture.
TEST(MixtureFactor, Constructor) {
MixtureFactor factor;
MixtureFactor::const_iterator const_it = factor.begin();
CHECK(const_it == factor.end());
MixtureFactor::iterator it = factor.begin();
CHECK(it == factor.end());
}
/* ************************************************************************* */
// Test .print() output.
TEST(MixtureFactor, Printing) {
DiscreteKey m1(1, 2);
double between0 = 0.0;
double between1 = 1.0;
Vector1 sigmas = Vector1(1.0);
auto model = noiseModel::Diagonal::Sigmas(sigmas, false);
auto f0 =
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between0, model);
auto f1 =
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model);
std::vector<NonlinearFactor::shared_ptr> factors{f0, f1};
MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
std::string expected =
R"(Hybrid [x1 x2; 1]
MixtureFactor
Choice(1)
0 Leaf Nonlinear factor on 2 keys
1 Leaf Nonlinear factor on 2 keys
)";
EXPECT(assert_print_equal(expected, mixtureFactor));
}
/* ************************************************************************* */
// Test the error of the MixtureFactor
TEST(MixtureFactor, Error) {
DiscreteKey m1(1, 2);
double between0 = 0.0;
double between1 = 1.0;
Vector1 sigmas = Vector1(1.0);
auto model = noiseModel::Diagonal::Sigmas(sigmas, false);
auto f0 =
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between0, model);
auto f1 =
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model);
std::vector<NonlinearFactor::shared_ptr> factors{f0, f1};
MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
Values continuousValues;
continuousValues.insert<double>(X(1), 0);
continuousValues.insert<double>(X(2), 1);
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> errors = {0.5, 0};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
EXPECT(assert_equal(expected_error, error_tree));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */