error method for HybridBayesNet
parent
5387299b8b
commit
95a534e7c1
|
@ -281,6 +281,36 @@ HybridValues HybridBayesNet::sample() const {
|
|||
return sample(&kRandomNumberGenerator);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> result(0.0);
|
||||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// If conditional is hybrid, compute error for all assignments.
|
||||
result = result + gm->error(continuousValues);
|
||||
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous, get the error and add it to the result
|
||||
double error = gc->error(continuousValues);
|
||||
// Add the computed error to every leaf of the result tree.
|
||||
result = result.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// If discrete, add the discrete error in the right branch
|
||||
result = result.apply(
|
||||
[dc](const Assignment<Key> &assignment, double leaf_value) {
|
||||
return leaf_value + dc->error(DiscreteValues(assignment));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||
const VectorValues &continuousValues) const {
|
||||
|
|
|
@ -187,6 +187,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute log probability for each discrete assignment,
|
||||
* and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which
|
||||
* to compute the log probability.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> logProbability(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
|
|
|
@ -153,6 +153,45 @@ TEST(HybridBayesNet, Choose) {
|
|||
*gbn.at(3)));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
|
||||
TEST(HybridBayesNet, Error) {
|
||||
const auto continuousConditional = GaussianConditional::sharedMeanAndStddev(
|
||||
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0);
|
||||
|
||||
const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)),
|
||||
model1 = noiseModel::Diagonal::Sigmas(Vector1(3.0));
|
||||
|
||||
const auto conditional0 = std::make_shared<GaussianConditional>(
|
||||
X(1), Vector1::Constant(5), I_1x1, model0),
|
||||
conditional1 = std::make_shared<GaussianConditional>(
|
||||
X(1), Vector1::Constant(2), I_1x1, model1);
|
||||
|
||||
auto gm =
|
||||
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1});
|
||||
// Create hybrid Bayes net.
|
||||
HybridBayesNet bayesNet;
|
||||
bayesNet.push_back(continuousConditional);
|
||||
bayesNet.emplace_back(gm);
|
||||
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
|
||||
|
||||
// Create values at which to evaluate.
|
||||
HybridValues values;
|
||||
values.insert(asiaKey, 0);
|
||||
values.insert(X(0), Vector1(-6));
|
||||
values.insert(X(1), Vector1(1));
|
||||
|
||||
AlgebraicDecisionTree<Key> actual_errors =
|
||||
bayesNet.error(values.continuous());
|
||||
|
||||
// Regression.
|
||||
// Manually added all the error values from the 3 conditional types.
|
||||
AlgebraicDecisionTree<Key> expected_errors(
|
||||
{Asia}, std::vector<double>{2.33005033585, 5.38619084965});
|
||||
|
||||
EXPECT(assert_equal(expected_errors, actual_errors));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test Bayes net optimize
|
||||
TEST(HybridBayesNet, OptimizeAssignment) {
|
||||
|
|
Loading…
Reference in New Issue