diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index cabfd28b8..c5ffed27b 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -271,15 +271,16 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { } /* *******************************************************************************/ -AlgebraicDecisionTree GaussianMixture::error( +AlgebraicDecisionTree GaussianMixture::logProbability( const VectorValues &continuousValues) const { - // functor to calculate to double error value from GaussianConditional. + // functor to calculate to double logProbability value from + // GaussianConditional. auto errorFunc = [continuousValues](const GaussianConditional::shared_ptr &conditional) { if (conditional) { - return conditional->error(continuousValues); + return conditional->logProbability(continuousValues); } else { - // Return arbitrarily large error if conditional is null + // Return arbitrarily large logProbability if conditional is null // Conditional is null if it is pruned out. return 1e50; } @@ -289,10 +290,10 @@ AlgebraicDecisionTree GaussianMixture::error( } /* *******************************************************************************/ -double GaussianMixture::error(const HybridValues &values) const { +double GaussianMixture::logProbability(const HybridValues &values) const { // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(values.discrete()); - return conditional->error(values.continuous()); + return conditional->logProbability(values.continuous()); } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index a8010e17c..a9f82d555 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -164,22 +164,23 @@ class GTSAM_EXPORT GaussianMixture const Conditionals &conditionals() const; /** - * @brief Compute error of the GaussianMixture as a tree. + * @brief Compute logProbability of the GaussianMixture as a tree. * * @param continuousValues The continuous VectorValues. * @return AlgebraicDecisionTree A decision tree with the same keys - * as the conditionals, and leaf values as the error. + * as the conditionals, and leaf values as the logProbability. */ - AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + AlgebraicDecisionTree logProbability( + const VectorValues &continuousValues) const; /** - * @brief Compute the error of this Gaussian Mixture given the continuous - * values and a discrete assignment. + * @brief Compute the logProbability of this Gaussian Mixture given the + * continuous values and a discrete assignment. * * @param values Continuous values and discrete assignment. * @return double */ - double error(const HybridValues &values) const override; + double logProbability(const HybridValues &values) const override; // /// Calculate probability density for given values `x`. // double evaluate(const HybridValues &values) const; diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e662b9c81..bc0d8e95e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -255,11 +255,6 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { return gbn.optimize(); } -/* ************************************************************************* */ -double HybridBayesNet::evaluate(const HybridValues &values) const { - return exp(-error(values)); -} - /* ************************************************************************* */ HybridValues HybridBayesNet::sample(const HybridValues &given, std::mt19937_64 *rng) const { @@ -296,23 +291,28 @@ HybridValues HybridBayesNet::sample() const { } /* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::error( +AlgebraicDecisionTree HybridBayesNet::logProbability( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree(0.0); // Iterate over each conditional. for (auto &&conditional : *this) { if (auto gm = conditional->asMixture()) { - // If conditional is hybrid, select based on assignment and compute error. - error_tree = error_tree + gm->error(continuousValues); + // If conditional is hybrid, select based on assignment and compute + // logProbability. + error_tree = error_tree + gm->logProbability(continuousValues); } else if (auto gc = conditional->asGaussian()) { - // If continuous, get the (double) error and add it to the error_tree - double error = gc->error(continuousValues); - // Add the computed error to every leaf of the error tree. - error_tree = error_tree.apply( - [error](double leaf_value) { return leaf_value + error; }); + // If continuous, get the (double) logProbability and add it to the + // error_tree + double logProbability = gc->logProbability(continuousValues); + // Add the computed logProbability to every leaf of the logProbability + // tree. + error_tree = error_tree.apply([logProbability](double leaf_value) { + return leaf_value + logProbability; + }); } else if (auto dc = conditional->asDiscrete()) { - // TODO(dellaert): if discrete, we need to add error in the right branch? + // TODO(dellaert): if discrete, we need to add logProbability in the right + // branch? continue; } } @@ -321,10 +321,15 @@ AlgebraicDecisionTree HybridBayesNet::error( } /* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::probPrime( +AlgebraicDecisionTree HybridBayesNet::evaluate( const VectorValues &continuousValues) const { - AlgebraicDecisionTree error_tree = this->error(continuousValues); - return error_tree.apply([](double error) { return exp(-error); }); + AlgebraicDecisionTree tree = this->logProbability(continuousValues); + return tree.apply([](double log) { return exp(log); }); +} + +/* ************************************************************************* */ +double HybridBayesNet::evaluate(const HybridValues &values) const { + return exp(logProbability(values)); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index c5a16f9dd..46a2b4f77 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -187,8 +187,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. HybridBayesNet prune(size_t maxNrLeaves); - using Base::error; // Expose error(const HybridValues&) method.. - /** * @brief Compute conditional error for each discrete assignment, * and return as a tree. @@ -196,7 +194,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @param continuousValues Continuous values at which to compute the error. * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + AlgebraicDecisionTree logProbability( + const VectorValues &continuousValues) const; + + using BayesNet::logProbability; // expose HybridValues version /** * @brief Compute unnormalized probability q(μ|M), @@ -208,7 +209,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * probability. * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree probPrime( + AlgebraicDecisionTree evaluate( const VectorValues &continuousValues) const; /** diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index f10a692af..df92ffcb8 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -122,18 +122,18 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const { } /* ************************************************************************ */ -double HybridConditional::error(const HybridValues &values) const { - if (auto gm = asMixture()) { - return gm->error(values); - } +double HybridConditional::logProbability(const HybridValues &values) const { if (auto gc = asGaussian()) { - return gc->error(values.continuous()); + return gc->logProbability(values.continuous()); + } + if (auto gm = asMixture()) { + return gm->logProbability(values); } if (auto dc = asDiscrete()) { - return -log((*dc)(values.discrete())); + return dc->logProbability(values.discrete()); } throw std::runtime_error( - "HybridConditional::error: conditional type not handled"); + "HybridConditional::logProbability: conditional type not handled"); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 6199fe7b0..030e6c835 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -176,8 +176,8 @@ class GTSAM_EXPORT HybridConditional /// Get the type-erased pointer to the inner type boost::shared_ptr inner() const { return inner_; } - /// Return the error of the underlying conditional. - double error(const HybridValues& values) const override; + /// Return the logProbability of the underlying conditional. + double logProbability(const HybridValues& values) const override; /// Check if VectorValues `measurements` contains all frontal keys. bool frontalsIn(const VectorValues& measurements) const { diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 0b2866921..a2ee2c21f 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -116,12 +116,12 @@ TEST(GaussianMixture, Error) { VectorValues values; values.insert(X(1), Vector2::Ones()); values.insert(X(2), Vector2::Zero()); - auto error_tree = mixture.error(values); + auto error_tree = mixture.logProbability(values); // Check result. std::vector discrete_keys = {m1}; - std::vector leaves = {conditional0->error(values), - conditional1->error(values)}; + std::vector leaves = {conditional0->logProbability(values), + conditional1->logProbability(values)}; AlgebraicDecisionTree expected_error(discrete_keys, leaves); EXPECT(assert_equal(expected_error, error_tree, 1e-6)); @@ -129,11 +129,11 @@ TEST(GaussianMixture, Error) { // Regression for non-tree version. DiscreteValues assignment; assignment[M(1)] = 0; - EXPECT_DOUBLES_EQUAL(conditional0->error(values), - mixture.error({values, assignment}), 1e-8); + EXPECT_DOUBLES_EQUAL(conditional0->logProbability(values), + mixture.logProbability({values, assignment}), 1e-8); assignment[M(1)] = 1; - EXPECT_DOUBLES_EQUAL(conditional1->error(values), - mixture.error({values, assignment}), 1e-8); + EXPECT_DOUBLES_EQUAL(conditional1->logProbability(values), + mixture.logProbability({values, assignment}), 1e-8); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 2c114a335..3af131f09 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -64,10 +64,10 @@ TEST(HybridBayesNet, Add) { // Test evaluate for a pure discrete Bayes net P(Asia). TEST(HybridBayesNet, EvaluatePureDiscrete) { HybridBayesNet bayesNet; - bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1")); + bayesNet.emplace_back(new DiscreteConditional(Asia, "4/6")); HybridValues values; values.insert(asiaKey, 0); - EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9); + EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9); } /* ****************************************************************************/ @@ -207,7 +207,7 @@ TEST(HybridBayesNet, Optimize) { /* ****************************************************************************/ // Test Bayes net error -TEST(HybridBayesNet, Error) { +TEST(HybridBayesNet, logProbability) { Switching s(3); HybridBayesNet::shared_ptr hybridBayesNet = @@ -215,42 +215,49 @@ TEST(HybridBayesNet, Error) { EXPECT_LONGS_EQUAL(5, hybridBayesNet->size()); HybridValues delta = hybridBayesNet->optimize(); - auto error_tree = hybridBayesNet->error(delta.continuous()); + auto error_tree = hybridBayesNet->logProbability(delta.continuous()); std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; - std::vector leaves = {-4.1609374, -4.1706942, -4.141568, -4.1609374}; + std::vector leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374}; AlgebraicDecisionTree expected_error(discrete_keys, leaves); // regression EXPECT(assert_equal(expected_error, error_tree, 1e-6)); - // Error on pruned Bayes net + // logProbability on pruned Bayes net auto prunedBayesNet = hybridBayesNet->prune(2); - auto pruned_error_tree = prunedBayesNet.error(delta.continuous()); + auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous()); - std::vector pruned_leaves = {2e50, -4.1706942, 2e50, -4.1609374}; + std::vector pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374}; AlgebraicDecisionTree expected_pruned_error(discrete_keys, pruned_leaves); // regression EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6)); - // Verify error computation and check for specific error value + // Verify logProbability computation and check for specific logProbability + // value const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; const HybridValues hybridValues{delta.continuous(), discrete_values}; - double error = 0; - error += hybridBayesNet->at(0)->asMixture()->error(hybridValues); - error += hybridBayesNet->at(1)->asMixture()->error(hybridValues); - error += hybridBayesNet->at(2)->asMixture()->error(hybridValues); + double logProbability = 0; + logProbability += + hybridBayesNet->at(0)->asMixture()->logProbability(hybridValues); + logProbability += + hybridBayesNet->at(1)->asMixture()->logProbability(hybridValues); + logProbability += + hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues); - // TODO(dellaert): the discrete errors are not added in error tree! - EXPECT_DOUBLES_EQUAL(error, error_tree(discrete_values), 1e-9); - EXPECT_DOUBLES_EQUAL(error, pruned_error_tree(discrete_values), 1e-9); - - error += hybridBayesNet->at(3)->asDiscrete()->error(discrete_values); - error += hybridBayesNet->at(4)->asDiscrete()->error(discrete_values); - EXPECT_DOUBLES_EQUAL(error, hybridBayesNet->error(hybridValues), 1e-9); + // TODO(dellaert): the discrete errors are not added in logProbability tree! + EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9); + EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values), + 1e-9); + logProbability += + hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values); + logProbability += + hybridBayesNet->at(4)->asDiscrete()->logProbability(discrete_values); + EXPECT_DOUBLES_EQUAL(logProbability, + hybridBayesNet->logProbability(hybridValues), 1e-9); } /* ****************************************************************************/