error method for HybridBayesNet

release/4.3a0
Varun Agrawal 2022-11-02 02:53:51 -04:00
parent ca14b7e6ec
commit 281ad3167e
4 changed files with 115 additions and 1 deletions

View File

@ -214,7 +214,12 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
// functor to convert from GaussianConditional to double error value.
auto errorFunc =
[continuousVals](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousVals);
if (conditional) {
return conditional->error(continuousVals);
} else {
// return arbitrarily large error
return 1e50;
}
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;

View File

@ -145,4 +145,45 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
return gbn.optimize();
}
/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = this->choose(discreteValues);
return gbn.error(continuousValues);
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree;
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> conditional_error;
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixture::shared_ptr gm = this->atMixture(idx);
conditional_error = gm->error(continuousValues);
if (idx == 0) {
error_tree = conditional_error;
} else {
error_tree = error_tree + conditional_error;
}
} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = this->atGaussian(idx)->error(continuousValues);
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}
return error_tree;
}
} // namespace gtsam

View File

@ -123,6 +123,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves) const;
/**
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/// @}
private:

View File

@ -183,6 +183,61 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}
/* ****************************************************************************/
// Test bayes net error
TEST(HybridBayesNet, Error) {
Switching s(3);
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->error(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(1), 2}, {M(2), 2}};
std::vector<double> leaves = {0.0097568009, 3.3973404e-31, 0.029126214,
0.0097568009};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-9));
// Error on pruned bayes net
auto prunedBayesNet = hybridBayesNet->prune(2);
auto pruned_error_tree = prunedBayesNet.error(delta.continuous());
std::vector<double> pruned_leaves = {2e50, 3.3973404e-31, 2e50, 0.0097568009};
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
pruned_leaves);
// regression
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9));
// Verify error computation and check for specific error value
DiscreteValues discrete_values;
discrete_values[M(1)] = 1;
discrete_values[M(2)] = 1;
double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
if (hybridBayesNet->at(idx)->isHybrid()) {
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(),
discrete_values);
total_error += error;
} else if (hybridBayesNet->at(idx)->isContinuous()) {
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
total_error += error;
}
}
EXPECT_DOUBLES_EQUAL(
total_error, hybridBayesNet->error(delta.continuous(), discrete_values),
1e-9);
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
}
/* ****************************************************************************/
// Test bayes net pruning
TEST(HybridBayesNet, Prune) {