HBN::evaluate

# Conflicts:
#	gtsam/hybrid/HybridBayesNet.cpp
#	gtsam/hybrid/tests/testHybridBayesNet.cpp
release/4.3a0
Frank Dellaert 2022-12-28 08:18:00 -05:00
parent 41a96473b5
commit b04f2f8582
3 changed files with 113 additions and 22 deletions

View File

@ -150,9 +150,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Go through all the conditionals in the
// Bayes Net and prune them as per decisionTree.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
@ -225,16 +223,50 @@ HybridValues HybridBayesNet::optimize() const {
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe);
GaussianBayesNet gbn = choose(mpe);
return HybridValues(mpe, gbn.optimize());
}
/* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);
GaussianBayesNet gbn = choose(assignment);
return gbn.optimize();
}
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
const DiscreteValues& discreteValues = values.discrete();
const VectorValues& continuosValues = values.continuous();
double probability = 1.0;
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
// If conditional is hybrid, select based on assignment and compute error.
// GaussianMixture::shared_ptr gm = conditional->asMixture();
// AlgebraicDecisionTree<Key> conditional_error =
// gm->error(continuousValues);
// error_tree = error_tree + conditional_error;
} else if (conditional->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
// double error = conditional->asGaussian()->error(continuousValues);
// // Add the computed error to every leaf of the error tree.
// error_tree = error_tree.apply(
// [error](double leaf_value) { return leaf_value + error; });
} else if (conditional->isDiscrete()) {
// Conditional is discrete-only, we skip.
probability *=
conditional->asDiscreteConditional()->operator()(discreteValues);
}
}
return probability;
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const {
@ -273,7 +305,7 @@ HybridValues HybridBayesNet::sample() const {
/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = this->choose(discreteValues);
GaussianBayesNet gbn = choose(discreteValues);
return gbn.error(continuousValues);
}

View File

@ -95,6 +95,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;
//** evaluate for given HybridValues */
double evaluate(const HybridValues &values) const;
//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const HybridValues &values) const {
return evaluate(values);
}
/**
* @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on

View File

@ -36,36 +36,87 @@ using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
static const DiscreteKey Asia(0, 2);
static const Key asiaKey = 0;
static const DiscreteKey Asia(asiaKey, 2);
/* ****************************************************************************/
// Test creation
// Test creation of a pure discrete Bayes net.
TEST(HybridBayesNet, Creation) {
HybridBayesNet bayesNet;
bayesNet.add(Asia, "99/1");
DiscreteConditional expected(Asia, "99/1");
CHECK(bayesNet.atDiscrete(0));
auto& df = *bayesNet.atDiscrete(0);
EXPECT(df.equals(expected));
EXPECT(assert_equal(expected, *bayesNet.atDiscrete(0)));
}
/* ****************************************************************************/
// Test adding a bayes net to another one.
// Test adding a Bayes net to another one.
TEST(HybridBayesNet, Add) {
HybridBayesNet bayesNet;
bayesNet.add(Asia, "99/1");
DiscreteConditional expected(Asia, "99/1");
HybridBayesNet other;
other.push_back(bayesNet);
EXPECT(bayesNet.equals(other));
}
/* ****************************************************************************/
// Test evaluate for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, evaluatePureDiscrete) {
HybridBayesNet bayesNet;
bayesNet.add(Asia, "99/1");
HybridValues values;
values.insert(asiaKey, 0);
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
}
/* ****************************************************************************/
// Test evaluate for a hybrid Bayes net P(X1|Asia) P(Asia).
TEST(HybridBayesNet, evaluateHybrid) {
HybridBayesNet bayesNet;
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
const Vector2 d0(1, 2);
Matrix22 R0 = Matrix22::Ones();
const auto conditional0 =
boost::make_shared<GaussianConditional>(X(1), d0, R0, model);
const Vector2 d1(2, 1);
Matrix22 R1 = Matrix22::Ones();
const auto conditional1 =
boost::make_shared<GaussianConditional>(X(1), d1, R1, model);
// TODO(dellaert): creating and adding mixture is clumsy.
std::vector<GaussianConditional::shared_ptr> conditionals{conditional0,
conditional1};
const auto mixture =
GaussianMixture::FromConditionals({X(1)}, {}, {Asia}, conditionals);
bayesNet.push_back(
HybridConditional(boost::make_shared<GaussianMixture>(mixture)));
// Add component probabilities.
bayesNet.add(Asia, "99/1");
// Create values at which to evaluate.
HybridValues values;
values.insert(asiaKey, 0);
values.insert(X(1), Vector2(1, 2));
// TODO(dellaert): we need real probabilities!
const double conditionalProbability =
conditional0->error(values.continuous());
EXPECT_DOUBLES_EQUAL(conditionalProbability * 0.99, bayesNet.evaluate(values), 1e-9);
}
// const Vector2 d1(2, 1);
// Matrix22 R1 = Matrix22::Ones();
// Matrix22 S1 = Matrix22::Identity() * 2;
// const auto conditional1 =
// boost::make_shared<GaussianConditional>(X(1), d1, R1, X(2), S1, model);
/* ****************************************************************************/
// Test choosing an assignment of conditionals
TEST(HybridBayesNet, Choose) {
@ -105,7 +156,7 @@ TEST(HybridBayesNet, Choose) {
}
/* ****************************************************************************/
// Test bayes net optimize
// Test Bayes net optimize
TEST(HybridBayesNet, OptimizeAssignment) {
Switching s(4);
@ -139,7 +190,7 @@ TEST(HybridBayesNet, OptimizeAssignment) {
}
/* ****************************************************************************/
// Test bayes net optimize
// Test Bayes net optimize
TEST(HybridBayesNet, Optimize) {
Switching s(4);
@ -203,7 +254,7 @@ TEST(HybridBayesNet, Error) {
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-9));
// Error on pruned bayes net
// Error on pruned Bayes net
auto prunedBayesNet = hybridBayesNet->prune(2);
auto pruned_error_tree = prunedBayesNet.error(delta.continuous());
@ -238,7 +289,7 @@ TEST(HybridBayesNet, Error) {
}
/* ****************************************************************************/
// Test bayes net pruning
// Test Bayes net pruning
TEST(HybridBayesNet, Prune) {
Switching s(4);
@ -256,7 +307,7 @@ TEST(HybridBayesNet, Prune) {
}
/* ****************************************************************************/
// Test bayes net updateDiscreteConditionals
// Test Bayes net updateDiscreteConditionals
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
Switching s(4);
@ -358,7 +409,7 @@ TEST(HybridBayesNet, Sampling) {
// Sample
HybridValues sample = bn->sample(&gen);
discrete_samples.push_back(sample.discrete()[M(0)]);
discrete_samples.push_back(sample.discrete().at(M(0)));
if (i == 0) {
average_continuous.insert(sample.continuous());