GaussianMixture error methods
parent
9365a02bdb
commit
ca14b7e6ec
|
@ -208,4 +208,23 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
|||
conditionals_.root_ = pruned_conditionals.root_;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||
const VectorValues &continuousVals) const {
|
||||
// functor to convert from GaussianConditional to double error value.
|
||||
auto errorFunc =
|
||||
[continuousVals](const GaussianConditional::shared_ptr &conditional) {
|
||||
return conditional->error(continuousVals);
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
||||
return errorTree;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixture::error(const VectorValues &continuousVals,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
auto conditional = conditionals_(discreteValues);
|
||||
return conditional->error(continuousVals);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -143,6 +143,26 @@ class GTSAM_EXPORT GaussianMixture
|
|||
/// Getter for the underlying Conditionals DecisionTree
|
||||
const Conditionals &conditionals();
|
||||
|
||||
/**
|
||||
* @brief Compute error of the GaussianMixture as a tree.
|
||||
*
|
||||
* @param continuousVals The continuous VectorValues.
|
||||
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
|
||||
* as the factor but leaf values as the error.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousVals) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||
* values and a discrete assignment.
|
||||
*
|
||||
* @param continuousVals The continuous values at which to compute the error.
|
||||
* @param discreteValues The discrete assignment for a specific mode sequence.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues &continuousVals,
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||
* `decisionTree`.
|
||||
|
|
|
@ -78,15 +78,51 @@ TEST(GaussianMixture, Equals) {
|
|||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals);
|
||||
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
|
||||
|
||||
// Let's check that this worked:
|
||||
DiscreteValues mode;
|
||||
mode[m1.first] = 1;
|
||||
auto actual = mixtureFactor(mode);
|
||||
auto actual = mixture(mode);
|
||||
EXPECT(actual == conditional1);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/// Test error method of GaussianMixture.
|
||||
TEST(GaussianMixture, Error) {
|
||||
Matrix22 S1 = Matrix22::Identity();
|
||||
Matrix22 S2 = Matrix22::Identity() * 2;
|
||||
Matrix22 R1 = Matrix22::Ones();
|
||||
Matrix22 R2 = Matrix22::Ones();
|
||||
Vector2 d1(1, 2), d2(2, 1);
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
|
||||
X(2), S1, model),
|
||||
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
|
||||
X(2), S2, model);
|
||||
|
||||
// Create decision tree
|
||||
DiscreteKey m1(1, 2);
|
||||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
|
||||
|
||||
VectorValues values;
|
||||
values.insert(X(1), Vector2::Ones());
|
||||
values.insert(X(2), Vector2::Zero());
|
||||
auto error_tree = mixture.error(values);
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
std::vector<double> leaves = {0.5, 4.3252595};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue