GaussianMixture error methods

release/4.3a0
Varun Agrawal 2022-11-01 20:19:36 -04:00
parent 9365a02bdb
commit ca14b7e6ec
3 changed files with 77 additions and 2 deletions

View File

@ -208,4 +208,23 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
conditionals_.root_ = pruned_conditionals.root_;
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousVals) const {
// functor to convert from GaussianConditional to double error value.
auto errorFunc =
[continuousVals](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousVals);
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}
/* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousVals,
const DiscreteValues &discreteValues) const {
auto conditional = conditionals_(discreteValues);
return conditional->error(continuousVals);
}
} // namespace gtsam

View File

@ -143,6 +143,26 @@ class GTSAM_EXPORT GaussianMixture
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();
/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousVals The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
* as the factor but leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousVals) const;
/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousVals The continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousVals,
const DiscreteValues &discreteValues) const;
/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.

View File

@ -78,15 +78,51 @@ TEST(GaussianMixture, Equals) {
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals);
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
// Let's check that this worked:
DiscreteValues mode;
mode[m1.first] = 1;
auto actual = mixtureFactor(mode);
auto actual = mixture(mode);
EXPECT(actual == conditional1);
}
/* ************************************************************************* */
/// Test error method of GaussianMixture.
TEST(GaussianMixture, Error) {
Matrix22 S1 = Matrix22::Identity();
Matrix22 S2 = Matrix22::Identity() * 2;
Matrix22 R1 = Matrix22::Ones();
Matrix22 R2 = Matrix22::Ones();
Vector2 d1(1, 2), d2(2, 1);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
X(2), S1, model),
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
X(2), S2, model);
// Create decision tree
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
VectorValues values;
values.insert(X(1), Vector2::Ones());
values.insert(X(2), Vector2::Zero());
auto error_tree = mixture.error(values);
std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> leaves = {0.5, 4.3252595};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
}
/* ************************************************************************* */
int main() {
TestResult tr;