Endowed hybrid with logProbability

release/4.3a0
Frank Dellaert 2023-01-10 21:55:18 -08:00
parent 11ef99b3f0
commit 426a49dc72
8 changed files with 84 additions and 69 deletions

View File

@ -271,15 +271,16 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
} }
/* *******************************************************************************/ /* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error( AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to calculate to double error value from GaussianConditional. // functor to calculate to double logProbability value from
// GaussianConditional.
auto errorFunc = auto errorFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) { [continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) { if (conditional) {
return conditional->error(continuousValues); return conditional->logProbability(continuousValues);
} else { } else {
// Return arbitrarily large error if conditional is null // Return arbitrarily large logProbability if conditional is null
// Conditional is null if it is pruned out. // Conditional is null if it is pruned out.
return 1e50; return 1e50;
} }
@ -289,10 +290,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
} }
/* *******************************************************************************/ /* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const { double GaussianMixture::logProbability(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree. // Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete()); auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()); return conditional->logProbability(values.continuous());
} }
} // namespace gtsam } // namespace gtsam

View File

@ -164,22 +164,23 @@ class GTSAM_EXPORT GaussianMixture
const Conditionals &conditionals() const; const Conditionals &conditionals() const;
/** /**
* @brief Compute error of the GaussianMixture as a tree. * @brief Compute logProbability of the GaussianMixture as a tree.
* *
* @param continuousValues The continuous VectorValues. * @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals, and leaf values as the error. * as the conditionals, and leaf values as the logProbability.
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;
/** /**
* @brief Compute the error of this Gaussian Mixture given the continuous * @brief Compute the logProbability of this Gaussian Mixture given the
* values and a discrete assignment. * continuous values and a discrete assignment.
* *
* @param values Continuous values and discrete assignment. * @param values Continuous values and discrete assignment.
* @return double * @return double
*/ */
double error(const HybridValues &values) const override; double logProbability(const HybridValues &values) const override;
// /// Calculate probability density for given values `x`. // /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const; // double evaluate(const HybridValues &values) const;

View File

@ -255,11 +255,6 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
return gbn.optimize(); return gbn.optimize();
} }
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
return exp(-error(values));
}
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given, HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const { std::mt19937_64 *rng) const {
@ -296,23 +291,28 @@ HybridValues HybridBayesNet::sample() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error( AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0); AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each conditional. // Iterate over each conditional.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) { if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute error. // If conditional is hybrid, select based on assignment and compute
error_tree = error_tree + gm->error(continuousValues); // logProbability.
error_tree = error_tree + gm->logProbability(continuousValues);
} else if (auto gc = conditional->asGaussian()) { } else if (auto gc = conditional->asGaussian()) {
// If continuous, get the (double) error and add it to the error_tree // If continuous, get the (double) logProbability and add it to the
double error = gc->error(continuousValues); // error_tree
// Add the computed error to every leaf of the error tree. double logProbability = gc->logProbability(continuousValues);
error_tree = error_tree.apply( // Add the computed logProbability to every leaf of the logProbability
[error](double leaf_value) { return leaf_value + error; }); // tree.
error_tree = error_tree.apply([logProbability](double leaf_value) {
return leaf_value + logProbability;
});
} else if (auto dc = conditional->asDiscrete()) { } else if (auto dc = conditional->asDiscrete()) {
// TODO(dellaert): if discrete, we need to add error in the right branch? // TODO(dellaert): if discrete, we need to add logProbability in the right
// branch?
continue; continue;
} }
} }
@ -321,10 +321,15 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
} }
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime( AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues); AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
return error_tree.apply([](double error) { return exp(-error); }); return tree.apply([](double log) { return exp(log); });
}
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
return exp(logProbability(values));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -187,8 +187,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves); HybridBayesNet prune(size_t maxNrLeaves);
using Base::error; // Expose error(const HybridValues&) method..
/** /**
* @brief Compute conditional error for each discrete assignment, * @brief Compute conditional error for each discrete assignment,
* and return as a tree. * and return as a tree.
@ -196,7 +194,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values at which to compute the error. * @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;
using BayesNet::logProbability; // expose HybridValues version
/** /**
* @brief Compute unnormalized probability q(μ|M), * @brief Compute unnormalized probability q(μ|M),
@ -208,7 +209,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* probability. * probability.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> probPrime( AlgebraicDecisionTree<Key> evaluate(
const VectorValues &continuousValues) const; const VectorValues &continuousValues) const;
/** /**

View File

@ -122,18 +122,18 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const { double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gm = asMixture()) {
return gm->error(values);
}
if (auto gc = asGaussian()) { if (auto gc = asGaussian()) {
return gc->error(values.continuous()); return gc->logProbability(values.continuous());
}
if (auto gm = asMixture()) {
return gm->logProbability(values);
} }
if (auto dc = asDiscrete()) { if (auto dc = asDiscrete()) {
return -log((*dc)(values.discrete())); return dc->logProbability(values.discrete());
} }
throw std::runtime_error( throw std::runtime_error(
"HybridConditional::error: conditional type not handled"); "HybridConditional::logProbability: conditional type not handled");
} }
} // namespace gtsam } // namespace gtsam

View File

@ -176,8 +176,8 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type /// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() const { return inner_; } boost::shared_ptr<Factor> inner() const { return inner_; }
/// Return the error of the underlying conditional. /// Return the logProbability of the underlying conditional.
double error(const HybridValues& values) const override; double logProbability(const HybridValues& values) const override;
/// Check if VectorValues `measurements` contains all frontal keys. /// Check if VectorValues `measurements` contains all frontal keys.
bool frontalsIn(const VectorValues& measurements) const { bool frontalsIn(const VectorValues& measurements) const {

View File

@ -116,12 +116,12 @@ TEST(GaussianMixture, Error) {
VectorValues values; VectorValues values;
values.insert(X(1), Vector2::Ones()); values.insert(X(1), Vector2::Ones());
values.insert(X(2), Vector2::Zero()); values.insert(X(2), Vector2::Zero());
auto error_tree = mixture.error(values); auto error_tree = mixture.logProbability(values);
// Check result. // Check result.
std::vector<DiscreteKey> discrete_keys = {m1}; std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> leaves = {conditional0->error(values), std::vector<double> leaves = {conditional0->logProbability(values),
conditional1->error(values)}; conditional1->logProbability(values)};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves); AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
EXPECT(assert_equal(expected_error, error_tree, 1e-6)); EXPECT(assert_equal(expected_error, error_tree, 1e-6));
@ -129,11 +129,11 @@ TEST(GaussianMixture, Error) {
// Regression for non-tree version. // Regression for non-tree version.
DiscreteValues assignment; DiscreteValues assignment;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(conditional0->error(values), EXPECT_DOUBLES_EQUAL(conditional0->logProbability(values),
mixture.error({values, assignment}), 1e-8); mixture.logProbability({values, assignment}), 1e-8);
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT_DOUBLES_EQUAL(conditional1->error(values), EXPECT_DOUBLES_EQUAL(conditional1->logProbability(values),
mixture.error({values, assignment}), 1e-8); mixture.logProbability({values, assignment}), 1e-8);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -64,10 +64,10 @@ TEST(HybridBayesNet, Add) {
// Test evaluate for a pure discrete Bayes net P(Asia). // Test evaluate for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, EvaluatePureDiscrete) { TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet; HybridBayesNet bayesNet;
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1")); bayesNet.emplace_back(new DiscreteConditional(Asia, "4/6"));
HybridValues values; HybridValues values;
values.insert(asiaKey, 0); values.insert(asiaKey, 0);
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9); EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
} }
/* ****************************************************************************/ /* ****************************************************************************/
@ -207,7 +207,7 @@ TEST(HybridBayesNet, Optimize) {
/* ****************************************************************************/ /* ****************************************************************************/
// Test Bayes net error // Test Bayes net error
TEST(HybridBayesNet, Error) { TEST(HybridBayesNet, logProbability) {
Switching s(3); Switching s(3);
HybridBayesNet::shared_ptr hybridBayesNet = HybridBayesNet::shared_ptr hybridBayesNet =
@ -215,42 +215,49 @@ TEST(HybridBayesNet, Error) {
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size()); EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
HybridValues delta = hybridBayesNet->optimize(); HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->error(delta.continuous()); auto error_tree = hybridBayesNet->logProbability(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}}; std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {-4.1609374, -4.1706942, -4.141568, -4.1609374}; std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves); AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression // regression
EXPECT(assert_equal(expected_error, error_tree, 1e-6)); EXPECT(assert_equal(expected_error, error_tree, 1e-6));
// Error on pruned Bayes net // logProbability on pruned Bayes net
auto prunedBayesNet = hybridBayesNet->prune(2); auto prunedBayesNet = hybridBayesNet->prune(2);
auto pruned_error_tree = prunedBayesNet.error(delta.continuous()); auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous());
std::vector<double> pruned_leaves = {2e50, -4.1706942, 2e50, -4.1609374}; std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys, AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
pruned_leaves); pruned_leaves);
// regression // regression
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6)); EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
// Verify error computation and check for specific error value // Verify logProbability computation and check for specific logProbability
// value
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const HybridValues hybridValues{delta.continuous(), discrete_values}; const HybridValues hybridValues{delta.continuous(), discrete_values};
double error = 0; double logProbability = 0;
error += hybridBayesNet->at(0)->asMixture()->error(hybridValues); logProbability +=
error += hybridBayesNet->at(1)->asMixture()->error(hybridValues); hybridBayesNet->at(0)->asMixture()->logProbability(hybridValues);
error += hybridBayesNet->at(2)->asMixture()->error(hybridValues); logProbability +=
hybridBayesNet->at(1)->asMixture()->logProbability(hybridValues);
logProbability +=
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
// TODO(dellaert): the discrete errors are not added in error tree! // TODO(dellaert): the discrete errors are not added in logProbability tree!
EXPECT_DOUBLES_EQUAL(error, error_tree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(error, pruned_error_tree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values),
1e-9);
error += hybridBayesNet->at(3)->asDiscrete()->error(discrete_values);
error += hybridBayesNet->at(4)->asDiscrete()->error(discrete_values);
EXPECT_DOUBLES_EQUAL(error, hybridBayesNet->error(hybridValues), 1e-9);
logProbability +=
hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);
logProbability +=
hybridBayesNet->at(4)->asDiscrete()->logProbability(discrete_values);
EXPECT_DOUBLES_EQUAL(logProbability,
hybridBayesNet->logProbability(hybridValues), 1e-9);
} }
/* ****************************************************************************/ /* ****************************************************************************/