HBN::evaluate

# Conflicts:
#   gtsam/hybrid/HybridBayesNet.cpp
#   gtsam/hybrid/tests/testHybridBayesNet.cpp

release/4.3a0

parent 41a96473b5
commit b04f2f8582
@@ -150,9 +150,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   // Go through all the conditionals in the
   // Bayes Net and prune them as per decisionTree.
-  for (size_t i = 0; i < this->size(); i++) {
-    HybridConditional::shared_ptr conditional = this->at(i);
+  for (auto &&conditional : *this) {
     if (conditional->isHybrid()) {
       GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
@@ -225,16 +223,50 @@ HybridValues HybridBayesNet::optimize() const {
   DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();

   // Given the MPE, compute the optimal continuous values.
-  GaussianBayesNet gbn = this->choose(mpe);
+  GaussianBayesNet gbn = choose(mpe);
   return HybridValues(mpe, gbn.optimize());
 }

 /* ************************************************************************* */
 VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
-  GaussianBayesNet gbn = this->choose(assignment);
+  GaussianBayesNet gbn = choose(assignment);
   return gbn.optimize();
 }

+/* ************************************************************************* */
+double HybridBayesNet::evaluate(const HybridValues &values) const {
+  const DiscreteValues &discreteValues = values.discrete();
+  const VectorValues &continuousValues = values.continuous();
+
+  double probability = 1.0;
+
+  // Iterate over each conditional.
+  for (auto &&conditional : *this) {
+    if (conditional->isHybrid()) {
+      // If conditional is hybrid, select based on assignment and compute error.
+      // GaussianMixture::shared_ptr gm = conditional->asMixture();
+      // AlgebraicDecisionTree<Key> conditional_error =
+      //     gm->error(continuousValues);
+
+      // error_tree = error_tree + conditional_error;
+
+    } else if (conditional->isContinuous()) {
+      // If continuous only, get the (double) error
+      // and add it to the error_tree
+      // double error = conditional->asGaussian()->error(continuousValues);
+      // // Add the computed error to every leaf of the error tree.
+      // error_tree = error_tree.apply(
+      //     [error](double leaf_value) { return leaf_value + error; });
+    } else if (conditional->isDiscrete()) {
+      // If discrete-only, multiply in the probability of the given assignment.
+      probability *=
+          conditional->asDiscreteConditional()->operator()(discreteValues);
+    }
+  }
+
+  return probability;
+}
+
 /* ************************************************************************* */
 HybridValues HybridBayesNet::sample(const HybridValues &given,
                                     std::mt19937_64 *rng) const {
@@ -273,7 +305,7 @@ HybridValues HybridBayesNet::sample() const {
 /* ************************************************************************* */
 double HybridBayesNet::error(const VectorValues &continuousValues,
                              const DiscreteValues &discreteValues) const {
-  GaussianBayesNet gbn = this->choose(discreteValues);
+  GaussianBayesNet gbn = choose(discreteValues);
   return gbn.error(continuousValues);
 }
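In the new evaluate() above, the hybrid and continuous branches are still commented out, so only discrete conditionals contribute to the returned probability. Purely as a sketch of where this seems headed (not what the commit implements), the Gaussian parts could be folded in as exp(-error), reusing only the calls already visible in the diff; the free-function wrapper, the include paths, and the omission of the Gaussian normalization constants are all assumptions here.

#include <cmath>  // std::exp

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

// Sketch only: unnormalized density of a HybridBayesNet at the given values.
double evaluateSketch(const gtsam::HybridBayesNet &bn,
                      const gtsam::HybridValues &values) {
  const gtsam::DiscreteValues &discreteValues = values.discrete();
  const gtsam::VectorValues &continuousValues = values.continuous();

  double probability = 1.0;
  for (auto &&conditional : bn) {
    if (conditional->isHybrid()) {
      // Errors of all mixture components, indexed by the discrete parents;
      // pick the component selected by the discrete assignment.
      const gtsam::AlgebraicDecisionTree<gtsam::Key> errors =
          conditional->asMixture()->error(continuousValues);
      probability *= std::exp(-errors(discreteValues));
    } else if (conditional->isContinuous()) {
      // Plain Gaussian conditional: turn its error into an unnormalized density.
      probability *=
          std::exp(-conditional->asGaussian()->error(continuousValues));
    } else if (conditional->isDiscrete()) {
      // Discrete conditional: probability of the discrete assignment.
      probability *=
          conditional->asDiscreteConditional()->operator()(discreteValues);
    }
  }
  return probability;
}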
@@ -95,6 +95,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   GaussianBayesNet choose(const DiscreteValues &assignment) const;

+  /** Evaluate for given HybridValues. */
+  double evaluate(const HybridValues &values) const;
+
+  /** (Preferred) sugar for the above, for given HybridValues. */
+  double operator()(const HybridValues &values) const {
+    return evaluate(values);
+  }
+
   /**
    * @brief Solve the HybridBayesNet by first computing the MPE of all the
    * discrete variables and then optimizing the continuous variables based on
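As a quick illustration of the evaluate()/operator() pair declared above, here is a minimal, hypothetical usage that simply mirrors the tests further down (include paths assumed). Note that the "99/1" spec used throughout is a ratio, normalized to P(Asia=0) = 99/(99+1) = 0.99 and P(Asia=1) = 0.01, which is where the 0.99 expected by evaluatePureDiscrete comes from.

#include <cassert>
#include <cmath>

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

int main() {
  gtsam::HybridBayesNet bayesNet;
  const gtsam::Key asiaKey = 0;
  const gtsam::DiscreteKey Asia(asiaKey, 2);
  bayesNet.add(Asia, "99/1");  // ratio 99:1 -> P = 0.99 / 0.01

  gtsam::HybridValues values;
  values.insert(asiaKey, 0);  // Asia = 0

  const double p = bayesNet(values);  // sugar for bayesNet.evaluate(values)
  assert(std::fabs(p - 0.99) < 1e-9);
  return 0;
}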
@@ -36,36 +36,87 @@ using noiseModel::Isotropic;
 using symbol_shorthand::M;
 using symbol_shorthand::X;

-static const DiscreteKey Asia(0, 2);
+static const Key asiaKey = 0;
+static const DiscreteKey Asia(asiaKey, 2);

 /* ****************************************************************************/
-// Test creation
+// Test creation of a pure discrete Bayes net.
 TEST(HybridBayesNet, Creation) {
   HybridBayesNet bayesNet;

   bayesNet.add(Asia, "99/1");

   DiscreteConditional expected(Asia, "99/1");

   CHECK(bayesNet.atDiscrete(0));
-  auto& df = *bayesNet.atDiscrete(0);
-  EXPECT(df.equals(expected));
+  EXPECT(assert_equal(expected, *bayesNet.atDiscrete(0)));
 }

 /* ****************************************************************************/
-// Test adding a bayes net to another one.
+// Test adding a Bayes net to another one.
 TEST(HybridBayesNet, Add) {
   HybridBayesNet bayesNet;

   bayesNet.add(Asia, "99/1");

   DiscreteConditional expected(Asia, "99/1");

   HybridBayesNet other;
   other.push_back(bayesNet);
   EXPECT(bayesNet.equals(other));
 }

+/* ****************************************************************************/
+// Test evaluate for a pure discrete Bayes net P(Asia).
+TEST(HybridBayesNet, evaluatePureDiscrete) {
+  HybridBayesNet bayesNet;
+  bayesNet.add(Asia, "99/1");
+  HybridValues values;
+  values.insert(asiaKey, 0);
+  EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
+}
+
+/* ****************************************************************************/
+// Test evaluate for a hybrid Bayes net P(X1|Asia) P(Asia).
+TEST(HybridBayesNet, evaluateHybrid) {
+  HybridBayesNet bayesNet;
+
+  SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
+
+  const Vector2 d0(1, 2);
+  Matrix22 R0 = Matrix22::Ones();
+  const auto conditional0 =
+      boost::make_shared<GaussianConditional>(X(1), d0, R0, model);
+
+  const Vector2 d1(2, 1);
+  Matrix22 R1 = Matrix22::Ones();
+  const auto conditional1 =
+      boost::make_shared<GaussianConditional>(X(1), d1, R1, model);
+
+  // TODO(dellaert): creating and adding a mixture is clumsy.
+  std::vector<GaussianConditional::shared_ptr> conditionals{conditional0,
+                                                            conditional1};
+  const auto mixture =
+      GaussianMixture::FromConditionals({X(1)}, {}, {Asia}, conditionals);
+  bayesNet.push_back(
+      HybridConditional(boost::make_shared<GaussianMixture>(mixture)));
+
+  // Add component probabilities.
+  bayesNet.add(Asia, "99/1");
+
+  // Create values at which to evaluate.
+  HybridValues values;
+  values.insert(asiaKey, 0);
+  values.insert(X(1), Vector2(1, 2));
+
+  // TODO(dellaert): we need real probabilities!
+  const double conditionalProbability =
+      conditional0->error(values.continuous());
+  EXPECT_DOUBLES_EQUAL(conditionalProbability * 0.99, bayesNet.evaluate(values), 1e-9);
+}
+
+// const Vector2 d1(2, 1);
+// Matrix22 R1 = Matrix22::Ones();
+// Matrix22 S1 = Matrix22::Identity() * 2;
+// const auto conditional1 =
+//     boost::make_shared<GaussianConditional>(X(1), d1, R1, X(2), S1, model);
+
 /* ****************************************************************************/
 // Test choosing an assignment of conditionals
 TEST(HybridBayesNet, Choose) {
@@ -105,7 +156,7 @@ TEST(HybridBayesNet, Choose) {
 }

 /* ****************************************************************************/
-// Test bayes net optimize
+// Test Bayes net optimize
 TEST(HybridBayesNet, OptimizeAssignment) {
   Switching s(4);

@@ -139,7 +190,7 @@ TEST(HybridBayesNet, OptimizeAssignment) {
 }

 /* ****************************************************************************/
-// Test bayes net optimize
+// Test Bayes net optimize
 TEST(HybridBayesNet, Optimize) {
   Switching s(4);

@@ -203,7 +254,7 @@ TEST(HybridBayesNet, Error) {
   // regression
   EXPECT(assert_equal(expected_error, error_tree, 1e-9));

-  // Error on pruned bayes net
+  // Error on pruned Bayes net
   auto prunedBayesNet = hybridBayesNet->prune(2);
   auto pruned_error_tree = prunedBayesNet.error(delta.continuous());

@@ -238,7 +289,7 @@ TEST(HybridBayesNet, Error) {
 }

 /* ****************************************************************************/
-// Test bayes net pruning
+// Test Bayes net pruning
 TEST(HybridBayesNet, Prune) {
   Switching s(4);

@@ -256,7 +307,7 @@ TEST(HybridBayesNet, Prune) {
 }

 /* ****************************************************************************/
-// Test bayes net updateDiscreteConditionals
+// Test Bayes net updateDiscreteConditionals
 TEST(HybridBayesNet, UpdateDiscreteConditionals) {
   Switching s(4);

@@ -358,7 +409,7 @@ TEST(HybridBayesNet, Sampling) {
     // Sample
     HybridValues sample = bn->sample(&gen);

-    discrete_samples.push_back(sample.discrete()[M(0)]);
+    discrete_samples.push_back(sample.discrete().at(M(0)));

     if (i == 0) {
       average_continuous.insert(sample.continuous());
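The Sampling test above pushes the sampled value of M(0) into discrete_samples and keeps a running average of the continuous part. A small sanity-check helper along those lines (not part of the commit; it assumes discrete_samples is a std::vector<size_t>, as the push_back above suggests) would compare the empirical frequency of a value against the corresponding discrete marginal encoded in the Bayes net:

#include <algorithm>
#include <cstddef>
#include <vector>

// Fraction of samples equal to `value`; with many samples this should
// approach the discrete marginal P(M(0) = value) implied by the Bayes net.
double empiricalFrequency(const std::vector<size_t> &discrete_samples,
                          size_t value) {
  const auto count =
      std::count(discrete_samples.begin(), discrete_samples.end(), value);
  return static_cast<double>(count) /
         static_cast<double>(discrete_samples.size());
}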