Add GaussianMixtureFactor::error method and unit test
parent
64744b057e
commit
c41b58fc98
|
|
@ -95,4 +95,16 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||||
};
|
};
|
||||||
return {factors_, wrap};
|
return {factors_, wrap};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
||||||
|
const VectorValues &continuousVals) const {
|
||||||
|
// functor to convert from sharedFactor to double error value.
|
||||||
|
auto errorFunc = [continuousVals](const GaussianFactor::shared_ptr &factor) {
|
||||||
|
return factor->error(continuousVals);
|
||||||
|
};
|
||||||
|
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||||
|
return errorTree;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
|
|
@ -131,17 +132,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
* @brief Compute error of the GaussianMixtureFactor as a tree.
|
* @brief Compute error of the GaussianMixtureFactor as a tree.
|
||||||
*
|
*
|
||||||
* @param continuousVals The continuous VectorValues.
|
* @param continuousVals The continuous VectorValues.
|
||||||
* @return DecisionTree<Key, double> A decision tree with corresponding keys
|
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
|
||||||
* as the factor but leaf values as the error.
|
* as the factor but leaf values as the error.
|
||||||
*/
|
*/
|
||||||
DecisionTree<Key, double> error(const VectorValues &c) const {
|
AlgebraicDecisionTree<Key> error(const VectorValues &continuousVals) const;
|
||||||
// functor to convert from sharedFactor to double error value.
|
|
||||||
auto errorFunc = [c](const GaussianFactor::shared_ptr &factor) {
|
|
||||||
return factor->error(c);
|
|
||||||
};
|
|
||||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
|
||||||
return errorTree;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add MixtureFactor to a Sum, syntactic sugar.
|
/// Add MixtureFactor to a Sum, syntactic sugar.
|
||||||
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file GaussianMixtureFactor.cpp
|
* @file testGaussianMixtureFactor.cpp
|
||||||
* @brief Unit tests for GaussianMixtureFactor
|
* @brief Unit tests for GaussianMixtureFactor
|
||||||
* @author Varun Agrawal
|
* @author Varun Agrawal
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
|
|
@ -135,7 +135,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
||||||
EXPECT(assert_print_equal(expected, mixtureFactor));
|
EXPECT(assert_print_equal(expected, mixtureFactor));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) {
|
TEST(GaussianMixtureFactor, GaussianMixture) {
|
||||||
KeyVector keys;
|
KeyVector keys;
|
||||||
keys.push_back(X(0));
|
keys.push_back(X(0));
|
||||||
keys.push_back(X(1));
|
keys.push_back(X(1));
|
||||||
|
|
@ -151,6 +151,39 @@ TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) {
|
||||||
EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size());
|
EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test the error of the GaussianMixtureFactor
|
||||||
|
TEST(GaussianMixtureFactor, Error) {
|
||||||
|
DiscreteKey m1(1, 2);
|
||||||
|
|
||||||
|
auto A01 = Matrix2::Identity();
|
||||||
|
auto A02 = Matrix2::Identity();
|
||||||
|
|
||||||
|
auto A11 = Matrix2::Identity();
|
||||||
|
auto A12 = Matrix2::Identity() * 2;
|
||||||
|
|
||||||
|
auto b = Vector2::Zero();
|
||||||
|
|
||||||
|
auto f0 = boost::make_shared<JacobianFactor>(X(1), A01, X(2), A02, b);
|
||||||
|
auto f1 = boost::make_shared<JacobianFactor>(X(1), A11, X(2), A12, b);
|
||||||
|
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
|
||||||
|
|
||||||
|
GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
|
||||||
|
|
||||||
|
VectorValues continuousVals;
|
||||||
|
continuousVals.insert(X(1), Vector2(0, 0));
|
||||||
|
continuousVals.insert(X(2), Vector2(1, 1));
|
||||||
|
|
||||||
|
// error should return a tree of errors, with nodes for each discrete value.
|
||||||
|
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousVals);
|
||||||
|
|
||||||
|
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||||
|
std::vector<double> errors = {1, 4};
|
||||||
|
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_error, error_tree));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue