error method for HybridBayesNet
parent
ca14b7e6ec
commit
281ad3167e
|
|
@ -214,7 +214,12 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||||
// functor to convert from GaussianConditional to double error value.
|
// functor to convert from GaussianConditional to double error value.
|
||||||
auto errorFunc =
|
auto errorFunc =
|
||||||
[continuousVals](const GaussianConditional::shared_ptr &conditional) {
|
[continuousVals](const GaussianConditional::shared_ptr &conditional) {
|
||||||
return conditional->error(continuousVals);
|
if (conditional) {
|
||||||
|
return conditional->error(continuousVals);
|
||||||
|
} else {
|
||||||
|
// return arbitrarily large error
|
||||||
|
return 1e50;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
||||||
return errorTree;
|
return errorTree;
|
||||||
|
|
|
||||||
|
|
@ -145,4 +145,45 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
||||||
return gbn.optimize();
|
return gbn.optimize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double HybridBayesNet::error(const VectorValues &continuousValues,
|
||||||
|
const DiscreteValues &discreteValues) const {
|
||||||
|
GaussianBayesNet gbn = this->choose(discreteValues);
|
||||||
|
return gbn.error(continuousValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||||
|
const VectorValues &continuousValues) const {
|
||||||
|
AlgebraicDecisionTree<Key> error_tree;
|
||||||
|
|
||||||
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
|
AlgebraicDecisionTree<Key> conditional_error;
|
||||||
|
if (factors_.at(idx)->isHybrid()) {
|
||||||
|
// If factor is hybrid, select based on assignment.
|
||||||
|
GaussianMixture::shared_ptr gm = this->atMixture(idx);
|
||||||
|
conditional_error = gm->error(continuousValues);
|
||||||
|
|
||||||
|
if (idx == 0) {
|
||||||
|
error_tree = conditional_error;
|
||||||
|
} else {
|
||||||
|
error_tree = error_tree + conditional_error;
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (factors_.at(idx)->isContinuous()) {
|
||||||
|
// If continuous only, get the (double) error
|
||||||
|
// and add it to the error_tree
|
||||||
|
double error = this->atGaussian(idx)->error(continuousValues);
|
||||||
|
error_tree = error_tree.apply(
|
||||||
|
[error](double leaf_value) { return leaf_value + error; });
|
||||||
|
|
||||||
|
} else if (factors_.at(idx)->isDiscrete()) {
|
||||||
|
// If factor at `idx` is discrete-only, we skip.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return error_tree;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -123,6 +123,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||||
HybridBayesNet prune(size_t maxNrLeaves) const;
|
HybridBayesNet prune(size_t maxNrLeaves) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief 0.5 * sum of squared Mahalanobis distances
|
||||||
|
* for a specific discrete assignment.
|
||||||
|
*
|
||||||
|
* @param continuousValues Continuous values at which to compute the error.
|
||||||
|
* @param discreteValues Discrete assignment for a specific mode sequence.
|
||||||
|
* @return double
|
||||||
|
*/
|
||||||
|
double error(const VectorValues &continuousValues,
|
||||||
|
const DiscreteValues &discreteValues) const;
|
||||||
|
|
||||||
|
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -183,6 +183,61 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
|
||||||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test bayes net error
|
||||||
|
TEST(HybridBayesNet, Error) {
|
||||||
|
Switching s(3);
|
||||||
|
|
||||||
|
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||||
|
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||||
|
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||||
|
|
||||||
|
HybridValues delta = hybridBayesNet->optimize();
|
||||||
|
auto error_tree = hybridBayesNet->error(delta.continuous());
|
||||||
|
|
||||||
|
std::vector<DiscreteKey> discrete_keys = {{M(1), 2}, {M(2), 2}};
|
||||||
|
std::vector<double> leaves = {0.0097568009, 3.3973404e-31, 0.029126214,
|
||||||
|
0.0097568009};
|
||||||
|
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||||
|
|
||||||
|
// regression
|
||||||
|
EXPECT(assert_equal(expected_error, error_tree, 1e-9));
|
||||||
|
|
||||||
|
// Error on pruned bayes net
|
||||||
|
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||||
|
auto pruned_error_tree = prunedBayesNet.error(delta.continuous());
|
||||||
|
|
||||||
|
std::vector<double> pruned_leaves = {2e50, 3.3973404e-31, 2e50, 0.0097568009};
|
||||||
|
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
|
||||||
|
pruned_leaves);
|
||||||
|
|
||||||
|
// regression
|
||||||
|
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9));
|
||||||
|
|
||||||
|
// Verify error computation and check for specific error value
|
||||||
|
DiscreteValues discrete_values;
|
||||||
|
discrete_values[M(1)] = 1;
|
||||||
|
discrete_values[M(2)] = 1;
|
||||||
|
|
||||||
|
double total_error = 0;
|
||||||
|
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
|
||||||
|
if (hybridBayesNet->at(idx)->isHybrid()) {
|
||||||
|
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(),
|
||||||
|
discrete_values);
|
||||||
|
total_error += error;
|
||||||
|
} else if (hybridBayesNet->at(idx)->isContinuous()) {
|
||||||
|
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
|
||||||
|
total_error += error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_DOUBLES_EQUAL(
|
||||||
|
total_error, hybridBayesNet->error(delta.continuous(), discrete_values),
|
||||||
|
1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test bayes net pruning
|
// Test bayes net pruning
|
||||||
TEST(HybridBayesNet, Prune) {
|
TEST(HybridBayesNet, Prune) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue