error method for HybridBayesNet
parent
5387299b8b
commit
95a534e7c1
|
@ -281,6 +281,36 @@ HybridValues HybridBayesNet::sample() const {
|
||||||
return sample(&kRandomNumberGenerator);
|
return sample(&kRandomNumberGenerator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||||
|
const VectorValues &continuousValues) const {
|
||||||
|
AlgebraicDecisionTree<Key> result(0.0);
|
||||||
|
|
||||||
|
// Iterate over each conditional.
|
||||||
|
for (auto &&conditional : *this) {
|
||||||
|
if (auto gm = conditional->asMixture()) {
|
||||||
|
// If conditional is hybrid, compute error for all assignments.
|
||||||
|
result = result + gm->error(continuousValues);
|
||||||
|
|
||||||
|
} else if (auto gc = conditional->asGaussian()) {
|
||||||
|
// If continuous, get the error and add it to the result
|
||||||
|
double error = gc->error(continuousValues);
|
||||||
|
// Add the computed error to every leaf of the result tree.
|
||||||
|
result = result.apply(
|
||||||
|
[error](double leaf_value) { return leaf_value + error; });
|
||||||
|
|
||||||
|
} else if (auto dc = conditional->asDiscrete()) {
|
||||||
|
// If discrete, add the discrete error in the right branch
|
||||||
|
result = result.apply(
|
||||||
|
[dc](const Assignment<Key> &assignment, double leaf_value) {
|
||||||
|
return leaf_value + dc->error(DiscreteValues(assignment));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
|
|
|
@ -187,6 +187,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @param continuousValues Continuous values at which to compute the error.
|
* @param continuousValues Continuous values at which to compute the error.
|
||||||
* @return AlgebraicDecisionTree<Key>
|
* @return AlgebraicDecisionTree<Key>
|
||||||
*/
|
*/
|
||||||
|
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute log probability for each discrete assignment,
|
||||||
|
* and return as a tree.
|
||||||
|
*
|
||||||
|
* @param continuousValues Continuous values at which
|
||||||
|
* to compute the log probability.
|
||||||
|
* @return AlgebraicDecisionTree<Key>
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree<Key> logProbability(
|
AlgebraicDecisionTree<Key> logProbability(
|
||||||
const VectorValues &continuousValues) const;
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
|
|
|
@ -153,6 +153,45 @@ TEST(HybridBayesNet, Choose) {
|
||||||
*gbn.at(3)));
|
*gbn.at(3)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
|
||||||
|
TEST(HybridBayesNet, Error) {
|
||||||
|
const auto continuousConditional = GaussianConditional::sharedMeanAndStddev(
|
||||||
|
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0);
|
||||||
|
|
||||||
|
const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)),
|
||||||
|
model1 = noiseModel::Diagonal::Sigmas(Vector1(3.0));
|
||||||
|
|
||||||
|
const auto conditional0 = std::make_shared<GaussianConditional>(
|
||||||
|
X(1), Vector1::Constant(5), I_1x1, model0),
|
||||||
|
conditional1 = std::make_shared<GaussianConditional>(
|
||||||
|
X(1), Vector1::Constant(2), I_1x1, model1);
|
||||||
|
|
||||||
|
auto gm =
|
||||||
|
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1});
|
||||||
|
// Create hybrid Bayes net.
|
||||||
|
HybridBayesNet bayesNet;
|
||||||
|
bayesNet.push_back(continuousConditional);
|
||||||
|
bayesNet.emplace_back(gm);
|
||||||
|
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
|
||||||
|
|
||||||
|
// Create values at which to evaluate.
|
||||||
|
HybridValues values;
|
||||||
|
values.insert(asiaKey, 0);
|
||||||
|
values.insert(X(0), Vector1(-6));
|
||||||
|
values.insert(X(1), Vector1(1));
|
||||||
|
|
||||||
|
AlgebraicDecisionTree<Key> actual_errors =
|
||||||
|
bayesNet.error(values.continuous());
|
||||||
|
|
||||||
|
// Regression.
|
||||||
|
// Manually added all the error values from the 3 conditional types.
|
||||||
|
AlgebraicDecisionTree<Key> expected_errors(
|
||||||
|
{Asia}, std::vector<double>{2.33005033585, 5.38619084965});
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_errors, actual_errors));
|
||||||
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test Bayes net optimize
|
// Test Bayes net optimize
|
||||||
TEST(HybridBayesNet, OptimizeAssignment) {
|
TEST(HybridBayesNet, OptimizeAssignment) {
|
||||||
|
|
Loading…
Reference in New Issue