diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 7a46d7832..12b56ece8 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -150,9 +150,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { // Go through all the conditionals in the // Bayes Net and prune them as per decisionTree. - for (size_t i = 0; i < this->size(); i++) { - HybridConditional::shared_ptr conditional = this->at(i); - + for (auto &&conditional : *this) { if (conditional->isHybrid()) { GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); @@ -225,16 +223,50 @@ HybridValues HybridBayesNet::optimize() const { DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize(); // Given the MPE, compute the optimal continuous values. - GaussianBayesNet gbn = this->choose(mpe); + GaussianBayesNet gbn = choose(mpe); return HybridValues(mpe, gbn.optimize()); } /* ************************************************************************* */ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { - GaussianBayesNet gbn = this->choose(assignment); + GaussianBayesNet gbn = choose(assignment); return gbn.optimize(); } +/* ************************************************************************* */ +double HybridBayesNet::evaluate(const HybridValues &values) const { + const DiscreteValues& discreteValues = values.discrete(); + const VectorValues& continuosValues = values.continuous(); + + double probability = 1.0; + + // Iterate over each conditional. + for (auto &&conditional : *this) { + if (conditional->isHybrid()) { + // If conditional is hybrid, select based on assignment and compute error. + // GaussianMixture::shared_ptr gm = conditional->asMixture(); + // AlgebraicDecisionTree conditional_error = + // gm->error(continuousValues); + + // error_tree = error_tree + conditional_error; + + } else if (conditional->isContinuous()) { + // If continuous only, get the (double) error + // and add it to the error_tree + // double error = conditional->asGaussian()->error(continuousValues); + // // Add the computed error to every leaf of the error tree. + // error_tree = error_tree.apply( + // [error](double leaf_value) { return leaf_value + error; }); + } else if (conditional->isDiscrete()) { + // Conditional is discrete-only, we skip. + probability *= + conditional->asDiscreteConditional()->operator()(discreteValues); + } + } + + return probability; +} + /* ************************************************************************* */ HybridValues HybridBayesNet::sample(const HybridValues &given, std::mt19937_64 *rng) const { @@ -273,7 +305,7 @@ HybridValues HybridBayesNet::sample() const { /* ************************************************************************* */ double HybridBayesNet::error(const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { - GaussianBayesNet gbn = this->choose(discreteValues); + GaussianBayesNet gbn = choose(discreteValues); return gbn.error(continuousValues); } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 1e6bebf1a..ff13ca1b7 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -95,6 +95,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ GaussianBayesNet choose(const DiscreteValues &assignment) const; + //** evaluate for given HybridValues */ + double evaluate(const HybridValues &values) const; + + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const HybridValues &values) const { + return evaluate(values); + } + /** * @brief Solve the HybridBayesNet by first computing the MPE of all the * discrete variables and then optimizing the continuous variables based on diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 3e3fab376..8f9632ba3 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -36,36 +36,87 @@ using noiseModel::Isotropic; using symbol_shorthand::M; using symbol_shorthand::X; -static const DiscreteKey Asia(0, 2); +static const Key asiaKey = 0; +static const DiscreteKey Asia(asiaKey, 2); /* ****************************************************************************/ -// Test creation +// Test creation of a pure discrete Bayes net. TEST(HybridBayesNet, Creation) { HybridBayesNet bayesNet; - bayesNet.add(Asia, "99/1"); DiscreteConditional expected(Asia, "99/1"); - CHECK(bayesNet.atDiscrete(0)); - auto& df = *bayesNet.atDiscrete(0); - EXPECT(df.equals(expected)); + EXPECT(assert_equal(expected, *bayesNet.atDiscrete(0))); } /* ****************************************************************************/ -// Test adding a bayes net to another one. +// Test adding a Bayes net to another one. TEST(HybridBayesNet, Add) { HybridBayesNet bayesNet; - bayesNet.add(Asia, "99/1"); - DiscreteConditional expected(Asia, "99/1"); - HybridBayesNet other; other.push_back(bayesNet); EXPECT(bayesNet.equals(other)); } +/* ****************************************************************************/ +// Test evaluate for a pure discrete Bayes net P(Asia). +TEST(HybridBayesNet, evaluatePureDiscrete) { + HybridBayesNet bayesNet; + bayesNet.add(Asia, "99/1"); + HybridValues values; + values.insert(asiaKey, 0); + EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9); +} + +/* ****************************************************************************/ +// Test evaluate for a hybrid Bayes net P(X1|Asia) P(Asia). +TEST(HybridBayesNet, evaluateHybrid) { + HybridBayesNet bayesNet; + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + const Vector2 d0(1, 2); + Matrix22 R0 = Matrix22::Ones(); + const auto conditional0 = + boost::make_shared(X(1), d0, R0, model); + + const Vector2 d1(2, 1); + Matrix22 R1 = Matrix22::Ones(); + const auto conditional1 = + boost::make_shared(X(1), d1, R1, model); + + // TODO(dellaert): creating and adding mixture is clumsy. + std::vector conditionals{conditional0, + conditional1}; + const auto mixture = + GaussianMixture::FromConditionals({X(1)}, {}, {Asia}, conditionals); + bayesNet.push_back( + HybridConditional(boost::make_shared(mixture))); + + // Add component probabilities. + bayesNet.add(Asia, "99/1"); + + // Create values at which to evaluate. + HybridValues values; + values.insert(asiaKey, 0); + values.insert(X(1), Vector2(1, 2)); + + // TODO(dellaert): we need real probabilities! + const double conditionalProbability = + conditional0->error(values.continuous()); + EXPECT_DOUBLES_EQUAL(conditionalProbability * 0.99, bayesNet.evaluate(values), 1e-9); +} + + // const Vector2 d1(2, 1); + // Matrix22 R1 = Matrix22::Ones(); + // Matrix22 S1 = Matrix22::Identity() * 2; + // const auto conditional1 = + // boost::make_shared(X(1), d1, R1, X(2), S1, model); + + /* ****************************************************************************/ // Test choosing an assignment of conditionals TEST(HybridBayesNet, Choose) { @@ -105,7 +156,7 @@ TEST(HybridBayesNet, Choose) { } /* ****************************************************************************/ -// Test bayes net optimize +// Test Bayes net optimize TEST(HybridBayesNet, OptimizeAssignment) { Switching s(4); @@ -139,7 +190,7 @@ TEST(HybridBayesNet, OptimizeAssignment) { } /* ****************************************************************************/ -// Test bayes net optimize +// Test Bayes net optimize TEST(HybridBayesNet, Optimize) { Switching s(4); @@ -203,7 +254,7 @@ TEST(HybridBayesNet, Error) { // regression EXPECT(assert_equal(expected_error, error_tree, 1e-9)); - // Error on pruned bayes net + // Error on pruned Bayes net auto prunedBayesNet = hybridBayesNet->prune(2); auto pruned_error_tree = prunedBayesNet.error(delta.continuous()); @@ -238,7 +289,7 @@ TEST(HybridBayesNet, Error) { } /* ****************************************************************************/ -// Test bayes net pruning +// Test Bayes net pruning TEST(HybridBayesNet, Prune) { Switching s(4); @@ -256,7 +307,7 @@ TEST(HybridBayesNet, Prune) { } /* ****************************************************************************/ -// Test bayes net updateDiscreteConditionals +// Test Bayes net updateDiscreteConditionals TEST(HybridBayesNet, UpdateDiscreteConditionals) { Switching s(4); @@ -358,7 +409,7 @@ TEST(HybridBayesNet, Sampling) { // Sample HybridValues sample = bn->sample(&gen); - discrete_samples.push_back(sample.discrete()[M(0)]); + discrete_samples.push_back(sample.discrete().at(M(0))); if (i == 0) { average_continuous.insert(sample.continuous());