error method for HybridBayesNet
parent
ca14b7e6ec
commit
281ad3167e
|
|
@ -214,7 +214,12 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
|||
// functor to convert from GaussianConditional to double error value.
|
||||
auto errorFunc =
|
||||
[continuousVals](const GaussianConditional::shared_ptr &conditional) {
|
||||
return conditional->error(continuousVals);
|
||||
if (conditional) {
|
||||
return conditional->error(continuousVals);
|
||||
} else {
|
||||
// return arbitrarily large error
|
||||
return 1e50;
|
||||
}
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
||||
return errorTree;
|
||||
|
|
|
|||
|
|
@ -145,4 +145,45 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
|||
return gbn.optimize();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double HybridBayesNet::error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
GaussianBayesNet gbn = this->choose(discreteValues);
|
||||
return gbn.error(continuousValues);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree;
|
||||
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
AlgebraicDecisionTree<Key> conditional_error;
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixture::shared_ptr gm = this->atMixture(idx);
|
||||
conditional_error = gm->error(continuousValues);
|
||||
|
||||
if (idx == 0) {
|
||||
error_tree = conditional_error;
|
||||
} else {
|
||||
error_tree = error_tree + conditional_error;
|
||||
}
|
||||
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
double error = this->atGaussian(idx)->error(continuousValues);
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return error_tree;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -123,6 +123,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
HybridBayesNet prune(size_t maxNrLeaves) const;
|
||||
|
||||
/**
|
||||
* @brief 0.5 * sum of squared Mahalanobis distances
|
||||
* for a specific discrete assignment.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @param discreteValues Discrete assignment for a specific mode sequence.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -183,6 +183,61 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
|
|||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net error
|
||||
TEST(HybridBayesNet, Error) {
|
||||
Switching s(3);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = hybridBayesNet->error(delta.continuous());
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(1), 2}, {M(2), 2}};
|
||||
std::vector<double> leaves = {0.0097568009, 3.3973404e-31, 0.029126214,
|
||||
0.0097568009};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-9));
|
||||
|
||||
// Error on pruned bayes net
|
||||
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||
auto pruned_error_tree = prunedBayesNet.error(delta.continuous());
|
||||
|
||||
std::vector<double> pruned_leaves = {2e50, 3.3973404e-31, 2e50, 0.0097568009};
|
||||
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
|
||||
pruned_leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9));
|
||||
|
||||
// Verify error computation and check for specific error value
|
||||
DiscreteValues discrete_values;
|
||||
discrete_values[M(1)] = 1;
|
||||
discrete_values[M(2)] = 1;
|
||||
|
||||
double total_error = 0;
|
||||
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
|
||||
if (hybridBayesNet->at(idx)->isHybrid()) {
|
||||
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(),
|
||||
discrete_values);
|
||||
total_error += error;
|
||||
} else if (hybridBayesNet->at(idx)->isContinuous()) {
|
||||
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
|
||||
total_error += error;
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(
|
||||
total_error, hybridBayesNet->error(delta.continuous(), discrete_values),
|
||||
1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net pruning
|
||||
TEST(HybridBayesNet, Prune) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue