add specific assignment error for GaussianMixtureFactor

release/4.3a0
Varun Agrawal 2022-11-01 14:01:20 -04:00
parent c0eeb0cfcd
commit 9365a02bdb
3 changed files with 26 additions and 0 deletions

View File

@ -107,4 +107,12 @@ AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
return errorTree;
}
/* *******************************************************************************/
double GaussianMixtureFactor::error(
const VectorValues &continuousVals,
const DiscreteValues &discreteValues) const {
auto factor = factors_(discreteValues);
return factor->error(continuousVals);
}
} // namespace gtsam

View File

@ -23,6 +23,7 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>
@ -137,6 +138,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousVals) const;
/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousVals The continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousVals,
const DiscreteValues &discreteValues) const;
/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);

View File

@ -182,6 +182,12 @@ TEST(GaussianMixtureFactor, Error) {
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
EXPECT(assert_equal(expected_error, error_tree));
// Test for single leaf given discrete assignment P(X|M,Z).
DiscreteValues discreteVals;
discreteVals[m1.first] = 1;
EXPECT_DOUBLES_EQUAL(4.0, mixtureFactor.error(continuousVals, discreteVals),
1e-9);
}
/* ************************************************************************* */