error method for HybridBayesNet

release/4.3a0
Varun Agrawal 2023-11-16 17:16:05 -05:00
parent 5387299b8b
commit 95a534e7c1
3 changed files with 79 additions and 0 deletions

View File

@ -281,6 +281,36 @@ HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator); return sample(&kRandomNumberGenerator);
} }
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, compute error for all assignments.
result = result + gm->error(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the error and add it to the result
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the result tree.
result = result.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete error in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->error(DiscreteValues(assignment));
});
}
}
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability( AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {

View File

@ -187,6 +187,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values at which to compute the error. * @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute log probability for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which
* to compute the log probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> logProbability( AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const; const VectorValues &continuousValues) const;

View File

@ -153,6 +153,45 @@ TEST(HybridBayesNet, Choose) {
*gbn.at(3))); *gbn.at(3)));
} }
/* ****************************************************************************/
// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
TEST(HybridBayesNet, Error) {
const auto continuousConditional = GaussianConditional::sharedMeanAndStddev(
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0);
const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)),
model1 = noiseModel::Diagonal::Sigmas(Vector1(3.0));
const auto conditional0 = std::make_shared<GaussianConditional>(
X(1), Vector1::Constant(5), I_1x1, model0),
conditional1 = std::make_shared<GaussianConditional>(
X(1), Vector1::Constant(2), I_1x1, model1);
auto gm =
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1});
// Create hybrid Bayes net.
HybridBayesNet bayesNet;
bayesNet.push_back(continuousConditional);
bayesNet.emplace_back(gm);
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
// Create values at which to evaluate.
HybridValues values;
values.insert(asiaKey, 0);
values.insert(X(0), Vector1(-6));
values.insert(X(1), Vector1(1));
AlgebraicDecisionTree<Key> actual_errors =
bayesNet.error(values.continuous());
// Regression.
// Manually added all the error values from the 3 conditional types.
AlgebraicDecisionTree<Key> expected_errors(
{Asia}, std::vector<double>{2.33005033585, 5.38619084965});
EXPECT(assert_equal(expected_errors, actual_errors));
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test Bayes net optimize // Test Bayes net optimize
TEST(HybridBayesNet, OptimizeAssignment) { TEST(HybridBayesNet, OptimizeAssignment) {