Add GaussianMixtureFactor::error method and unit test

release/4.3a0
Varun Agrawal 2022-10-31 15:36:48 -04:00
parent 64744b057e
commit c41b58fc98
3 changed files with 50 additions and 11 deletions

View File

@ -95,4 +95,16 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
};
return {factors_, wrap};
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousVals) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousVals](const GaussianFactor::shared_ptr &factor) {
return factor->error(continuousVals);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}
} // namespace gtsam

View File

@ -20,6 +20,7 @@
#pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
@ -131,17 +132,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Compute error of the GaussianMixtureFactor as a tree.
*
* @param continuousVals The continuous VectorValues.
* @return DecisionTree<Key, double> A decision tree with corresponding keys
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
* as the factor but leaf values as the error.
*/
DecisionTree<Key, double> error(const VectorValues &c) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [c](const GaussianFactor::shared_ptr &factor) {
return factor->error(c);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}
AlgebraicDecisionTree<Key> error(const VectorValues &continuousVals) const;
/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */
/**
* @file GaussianMixtureFactor.cpp
* @file testGaussianMixtureFactor.cpp
* @brief Unit tests for GaussianMixtureFactor
* @author Varun Agrawal
* @author Fan Jiang
@ -135,7 +135,7 @@ TEST(GaussianMixtureFactor, Printing) {
EXPECT(assert_print_equal(expected, mixtureFactor));
}
TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) {
TEST(GaussianMixtureFactor, GaussianMixture) {
KeyVector keys;
keys.push_back(X(0));
keys.push_back(X(1));
@ -151,6 +151,39 @@ TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) {
EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size());
}
/* ************************************************************************* */
// Test the error of the GaussianMixtureFactor
TEST(GaussianMixtureFactor, Error) {
DiscreteKey m1(1, 2);
auto A01 = Matrix2::Identity();
auto A02 = Matrix2::Identity();
auto A11 = Matrix2::Identity();
auto A12 = Matrix2::Identity() * 2;
auto b = Vector2::Zero();
auto f0 = boost::make_shared<JacobianFactor>(X(1), A01, X(2), A02, b);
auto f1 = boost::make_shared<JacobianFactor>(X(1), A11, X(2), A12, b);
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
VectorValues continuousVals;
continuousVals.insert(X(1), Vector2(0, 0));
continuousVals.insert(X(2), Vector2(1, 1));
// error should return a tree of errors, with nodes for each discrete value.
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousVals);
std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> errors = {1, 4};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
EXPECT(assert_equal(expected_error, error_tree));
}
/* ************************************************************************* */
int main() {
TestResult tr;