GaussianMixture error methods
parent
9365a02bdb
commit
ca14b7e6ec
|
@ -208,4 +208,23 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
||||||
conditionals_.root_ = pruned_conditionals.root_;
|
conditionals_.root_ = pruned_conditionals.root_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||||
|
const VectorValues &continuousVals) const {
|
||||||
|
// functor to convert from GaussianConditional to double error value.
|
||||||
|
auto errorFunc =
|
||||||
|
[continuousVals](const GaussianConditional::shared_ptr &conditional) {
|
||||||
|
return conditional->error(continuousVals);
|
||||||
|
};
|
||||||
|
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
||||||
|
return errorTree;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
double GaussianMixture::error(const VectorValues &continuousVals,
|
||||||
|
const DiscreteValues &discreteValues) const {
|
||||||
|
auto conditional = conditionals_(discreteValues);
|
||||||
|
return conditional->error(continuousVals);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -143,6 +143,26 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/// Getter for the underlying Conditionals DecisionTree
|
/// Getter for the underlying Conditionals DecisionTree
|
||||||
const Conditionals &conditionals();
|
const Conditionals &conditionals();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute error of the GaussianMixture as a tree.
|
||||||
|
*
|
||||||
|
* @param continuousVals The continuous VectorValues.
|
||||||
|
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
|
||||||
|
* as the factor but leaf values as the error.
|
||||||
|
*/
|
||||||
|
AlgebraicDecisionTree<Key> error(const VectorValues &continuousVals) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||||
|
* values and a discrete assignment.
|
||||||
|
*
|
||||||
|
* @param continuousVals The continuous values at which to compute the error.
|
||||||
|
* @param discreteValues The discrete assignment for a specific mode sequence.
|
||||||
|
* @return double
|
||||||
|
*/
|
||||||
|
double error(const VectorValues &continuousVals,
|
||||||
|
const DiscreteValues &discreteValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||||
* `decisionTree`.
|
* `decisionTree`.
|
||||||
|
|
|
@ -78,15 +78,51 @@ TEST(GaussianMixture, Equals) {
|
||||||
GaussianMixture::Conditionals conditionals(
|
GaussianMixture::Conditionals conditionals(
|
||||||
{m1},
|
{m1},
|
||||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||||
GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals);
|
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
|
||||||
|
|
||||||
// Let's check that this worked:
|
// Let's check that this worked:
|
||||||
DiscreteValues mode;
|
DiscreteValues mode;
|
||||||
mode[m1.first] = 1;
|
mode[m1.first] = 1;
|
||||||
auto actual = mixtureFactor(mode);
|
auto actual = mixture(mode);
|
||||||
EXPECT(actual == conditional1);
|
EXPECT(actual == conditional1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
/// Test error method of GaussianMixture.
|
||||||
|
TEST(GaussianMixture, Error) {
|
||||||
|
Matrix22 S1 = Matrix22::Identity();
|
||||||
|
Matrix22 S2 = Matrix22::Identity() * 2;
|
||||||
|
Matrix22 R1 = Matrix22::Ones();
|
||||||
|
Matrix22 R2 = Matrix22::Ones();
|
||||||
|
Vector2 d1(1, 2), d2(2, 1);
|
||||||
|
|
||||||
|
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||||
|
|
||||||
|
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
|
||||||
|
X(2), S1, model),
|
||||||
|
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
|
||||||
|
X(2), S2, model);
|
||||||
|
|
||||||
|
// Create decision tree
|
||||||
|
DiscreteKey m1(1, 2);
|
||||||
|
GaussianMixture::Conditionals conditionals(
|
||||||
|
{m1},
|
||||||
|
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||||
|
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
|
||||||
|
|
||||||
|
VectorValues values;
|
||||||
|
values.insert(X(1), Vector2::Ones());
|
||||||
|
values.insert(X(2), Vector2::Zero());
|
||||||
|
auto error_tree = mixture.error(values);
|
||||||
|
|
||||||
|
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||||
|
std::vector<double> leaves = {0.5, 4.3252595};
|
||||||
|
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||||
|
|
||||||
|
// regression
|
||||||
|
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue