Add GaussianMixtureFactor::error method and unit test
parent
64744b057e
commit
c41b58fc98
|
|
@ -95,4 +95,16 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
|||
};
|
||||
return {factors_, wrap};
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
||||
const VectorValues &continuousVals) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc = [continuousVals](const GaussianFactor::shared_ptr &factor) {
|
||||
return factor->error(continuousVals);
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||
return errorTree;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
|
|
@ -131,17 +132,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
* @brief Compute error of the GaussianMixtureFactor as a tree.
|
||||
*
|
||||
* @param continuousVals The continuous VectorValues.
|
||||
* @return DecisionTree<Key, double> A decision tree with corresponding keys
|
||||
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
|
||||
* as the factor but leaf values as the error.
|
||||
*/
|
||||
DecisionTree<Key, double> error(const VectorValues &c) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc = [c](const GaussianFactor::shared_ptr &factor) {
|
||||
return factor->error(c);
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||
return errorTree;
|
||||
}
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousVals) const;
|
||||
|
||||
/// Add MixtureFactor to a Sum, syntactic sugar.
|
||||
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file GaussianMixtureFactor.cpp
|
||||
* @file testGaussianMixtureFactor.cpp
|
||||
* @brief Unit tests for GaussianMixtureFactor
|
||||
* @author Varun Agrawal
|
||||
* @author Fan Jiang
|
||||
|
|
@ -135,7 +135,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
|||
EXPECT(assert_print_equal(expected, mixtureFactor));
|
||||
}
|
||||
|
||||
TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) {
|
||||
TEST(GaussianMixtureFactor, GaussianMixture) {
|
||||
KeyVector keys;
|
||||
keys.push_back(X(0));
|
||||
keys.push_back(X(1));
|
||||
|
|
@ -151,6 +151,39 @@ TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) {
|
|||
EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test the error of the GaussianMixtureFactor
|
||||
TEST(GaussianMixtureFactor, Error) {
|
||||
DiscreteKey m1(1, 2);
|
||||
|
||||
auto A01 = Matrix2::Identity();
|
||||
auto A02 = Matrix2::Identity();
|
||||
|
||||
auto A11 = Matrix2::Identity();
|
||||
auto A12 = Matrix2::Identity() * 2;
|
||||
|
||||
auto b = Vector2::Zero();
|
||||
|
||||
auto f0 = boost::make_shared<JacobianFactor>(X(1), A01, X(2), A02, b);
|
||||
auto f1 = boost::make_shared<JacobianFactor>(X(1), A11, X(2), A12, b);
|
||||
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
|
||||
|
||||
GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
|
||||
|
||||
VectorValues continuousVals;
|
||||
continuousVals.insert(X(1), Vector2(0, 0));
|
||||
continuousVals.insert(X(2), Vector2(1, 1));
|
||||
|
||||
// error should return a tree of errors, with nodes for each discrete value.
|
||||
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousVals);
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
std::vector<double> errors = {1, 4};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
|
||||
|
||||
EXPECT(assert_equal(expected_error, error_tree));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
Loading…
Reference in New Issue