Endowed hybrid with logProbability

release/4.3a0
Frank Dellaert 2023-01-10 21:55:18 -08:00
parent 11ef99b3f0
commit 426a49dc72
8 changed files with 84 additions and 69 deletions

View File

@@ -271,15 +271,16 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
const VectorValues &continuousValues) const {
// functor to calculate to double error value from GaussianConditional.
// functor to calculate the double logProbability value from a
// GaussianConditional.
auto errorFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->error(continuousValues);
return conditional->logProbability(continuousValues);
} else {
// Return arbitrarily large error if conditional is null
// Return arbitrarily large logProbability if conditional is null
// Conditional is null if it is pruned out.
return 1e50;
}
@@ -289,10 +290,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
}
/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
double GaussianMixture::logProbability(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous());
return conditional->logProbability(values.continuous());
}
} // namespace gtsam
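Illustrative sketch, not part of the commit: how a caller might exercise the two new logProbability overloads, assuming a GaussianMixture `mixture`, a matching `continuousValues`, and a discrete `assignment` built elsewhere (hypothetical names, header paths assumed). The tree leaf for an assignment should match the direct HybridValues overload, mirroring the regression test further down.

#include <cassert>
#include <cmath>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridValues.h>

void checkMixtureLogProbability(const gtsam::GaussianMixture& mixture,
                                const gtsam::VectorValues& continuousValues,
                                const gtsam::DiscreteValues& assignment) {
  // One logProbability leaf per discrete assignment.
  const auto tree = mixture.logProbability(continuousValues);
  // Direct evaluation for a single assignment (indexes the conditionals).
  const double direct = mixture.logProbability({continuousValues, assignment});
  assert(std::abs(tree(assignment) - direct) < 1e-9);
}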

View File

@@ -164,22 +164,23 @@ class GTSAM_EXPORT GaussianMixture
const Conditionals &conditionals() const;
/**
* @brief Compute error of the GaussianMixture as a tree.
* @brief Compute logProbability of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals, and leaf values as the error.
* as the conditionals, and leaf values as the logProbability.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;
/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
* @brief Compute the logProbability of this Gaussian Mixture given the
* continuous values and a discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;
double logProbability(const HybridValues &values) const override;
// /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const;
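Illustrative sketch, not part of the commit: the commented-out `evaluate` above hints at the intended relationship density = exp(logProbability). A hedged version as a free helper (hypothetical name), so as not to presume the eventual member signature:

#include <cmath>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridValues.h>

// Hypothetical helper: probability density of the mixture at `values`,
// obtained by exponentiating the new logProbability method.
double evaluateMixture(const gtsam::GaussianMixture& mixture,
                       const gtsam::HybridValues& values) {
  return std::exp(mixture.logProbability(values));
}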

View File

@@ -255,11 +255,6 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
return gbn.optimize();
}
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
return exp(-error(values));
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const {
@@ -296,23 +291,28 @@ HybridValues HybridBayesNet::sample() const {
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute error.
error_tree = error_tree + gm->error(continuousValues);
// If conditional is hybrid, select based on assignment and compute
// logProbability.
error_tree = error_tree + gm->logProbability(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the (double) error and add it to the error_tree
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
// If continuous, get the (double) logProbability and add it to the
// error_tree
double logProbability = gc->logProbability(continuousValues);
// Add the computed logProbability to every leaf of the logProbability
// tree.
error_tree = error_tree.apply([logProbability](double leaf_value) {
return leaf_value + logProbability;
});
} else if (auto dc = conditional->asDiscrete()) {
// TODO(dellaert): if discrete, we need to add error in the right branch?
// TODO(dellaert): if discrete, we need to add logProbability in the right
// branch?
continue;
}
}
@@ -321,10 +321,15 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
return error_tree.apply([](double error) { return exp(-error); });
AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
return tree.apply([](double log) { return exp(log); });
}
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
return exp(logProbability(values));
}
/* ************************************************************************* */
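Illustrative sketch, not part of the commit: how the two new `evaluate` overloads relate, assuming a populated `bn` and compatible values (hypothetical names). Note the TODO above: discrete conditionals are not yet folded into the tree, so a tree leaf and the scalar value can differ by exactly the discrete contribution.

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

void evaluateSketch(const gtsam::HybridBayesNet& bn,
                    const gtsam::VectorValues& continuousValues,
                    const gtsam::DiscreteValues& assignment) {
  // Tree of probabilities: exp applied leafwise to the logProbability tree.
  const auto probTree = bn.evaluate(continuousValues);
  // Scalar probability for one full hybrid assignment (includes discrete).
  const gtsam::HybridValues hybridValues{continuousValues, assignment};
  const double p = bn.evaluate(hybridValues);
  // probTree(assignment) omits the discrete conditionals (see TODO above),
  // so it equals p only when the net has none.
  (void)probTree;
  (void)p;
}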

View File

@@ -187,8 +187,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves);
using Base::error; // Expose error(const HybridValues&) method..
/**
* @brief Compute conditional error for each discrete assignment,
* and return as a tree.
@@ -196,7 +194,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;
using BayesNet::logProbability; // expose HybridValues version
/**
* @brief Compute unnormalized probability q(μ|M),
@@ -208,7 +209,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
AlgebraicDecisionTree<Key> evaluate(
const VectorValues &continuousValues) const;
/**
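Illustrative sketch, not part of the commit: the `using BayesNet::logProbability;` line above is needed because a derived-class overload hides every base-class overload of the same name. A minimal, GTSAM-free illustration of that C++ rule:

#include <iostream>

struct Base {
  double logProbability(int scalar) const { return scalar * 1.0; }
};

struct Derived : Base {
  using Base::logProbability;  // without this, the int overload is hidden
  double logProbability(double x) const { return x * 2.0; }
};

int main() {
  Derived d;
  std::cout << d.logProbability(3) << "\n";    // Base overload: prints 3
  std::cout << d.logProbability(3.0) << "\n";  // Derived overload: prints 6
  return 0;
}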

View File

@@ -122,18 +122,18 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
}
/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
if (auto gm = asMixture()) {
return gm->error(values);
}
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
return gc->logProbability(values.continuous());
}
if (auto gm = asMixture()) {
return gm->logProbability(values);
}
if (auto dc = asDiscrete()) {
return -log((*dc)(values.discrete()));
return dc->logProbability(values.discrete());
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
"HybridConditional::logProbability: conditional type not handled");
}
} // namespace gtsam
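Illustrative sketch, not part of the commit: because HybridConditional::logProbability dispatches on the wrapped type, a caller can accumulate the joint log-probability over a whole HybridBayesNet without switching on conditional types, assuming a populated `bn` and compatible `values` (hypothetical names):

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

double sumLogProbability(const gtsam::HybridBayesNet& bn,
                         const gtsam::HybridValues& values) {
  double total = 0.0;
  // Each HybridConditional forwards to its Gaussian, mixture, or discrete
  // inner conditional, so no asGaussian()/asMixture()/asDiscrete() needed.
  for (const auto& conditional : bn) {
    total += conditional->logProbability(values);
  }
  return total;  // should agree with bn.logProbability(values)
}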

View File

@@ -176,8 +176,8 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() const { return inner_; }
/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;
/// Return the logProbability of the underlying conditional.
double logProbability(const HybridValues& values) const override;
/// Check if VectorValues `measurements` contains all frontal keys.
bool frontalsIn(const VectorValues& measurements) const {

View File

@@ -116,12 +116,12 @@ TEST(GaussianMixture, Error) {
VectorValues values;
values.insert(X(1), Vector2::Ones());
values.insert(X(2), Vector2::Zero());
auto error_tree = mixture.error(values);
auto error_tree = mixture.logProbability(values);
// Check result.
std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> leaves = {conditional0->error(values),
conditional1->error(values)};
std::vector<double> leaves = {conditional0->logProbability(values),
conditional1->logProbability(values)};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
@@ -129,11 +129,11 @@ TEST(GaussianMixture, Error) {
// Regression for non-tree version.
DiscreteValues assignment;
assignment[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(conditional0->error(values),
mixture.error({values, assignment}), 1e-8);
EXPECT_DOUBLES_EQUAL(conditional0->logProbability(values),
mixture.logProbability({values, assignment}), 1e-8);
assignment[M(1)] = 1;
EXPECT_DOUBLES_EQUAL(conditional1->error(values),
mixture.error({values, assignment}), 1e-8);
EXPECT_DOUBLES_EQUAL(conditional1->logProbability(values),
mixture.logProbability({values, assignment}), 1e-8);
}
/* ************************************************************************* */
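Illustrative sketch, not part of the commit: the mixture's logProbability tree can be turned into per-assignment densities with the same leafwise exp used by HybridBayesNet::evaluate, assuming `mixture` and `values` as in the test above (hypothetical function name):

#include <cmath>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/hybrid/GaussianMixture.h>

gtsam::AlgebraicDecisionTree<gtsam::Key> mixtureDensities(
    const gtsam::GaussianMixture& mixture, const gtsam::VectorValues& values) {
  // Leafwise exp of logProbability, mirroring HybridBayesNet::evaluate.
  return mixture.logProbability(values).apply(
      [](double logP) { return std::exp(logP); });
}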

View File

@@ -64,10 +64,10 @@ TEST(HybridBayesNet, Add) {
// Test evaluate for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet;
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
bayesNet.emplace_back(new DiscreteConditional(Asia, "4/6"));
HybridValues values;
values.insert(asiaKey, 0);
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
}
/* ****************************************************************************/
@@ -207,7 +207,7 @@ TEST(HybridBayesNet, Optimize) {
/* ****************************************************************************/
// Test Bayes net error
TEST(HybridBayesNet, Error) {
TEST(HybridBayesNet, logProbability) {
Switching s(3);
HybridBayesNet::shared_ptr hybridBayesNet =
@@ -215,42 +215,49 @@ TEST(HybridBayesNet, Error) {
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->error(delta.continuous());
auto error_tree = hybridBayesNet->logProbability(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {-4.1609374, -4.1706942, -4.141568, -4.1609374};
std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
// Error on pruned Bayes net
// logProbability on pruned Bayes net
auto prunedBayesNet = hybridBayesNet->prune(2);
auto pruned_error_tree = prunedBayesNet.error(delta.continuous());
auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous());
std::vector<double> pruned_leaves = {2e50, -4.1706942, 2e50, -4.1609374};
std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
pruned_leaves);
// regression
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
// Verify error computation and check for specific error value
// Verify logProbability computation and check for specific logProbability
// value
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const HybridValues hybridValues{delta.continuous(), discrete_values};
double error = 0;
error += hybridBayesNet->at(0)->asMixture()->error(hybridValues);
error += hybridBayesNet->at(1)->asMixture()->error(hybridValues);
error += hybridBayesNet->at(2)->asMixture()->error(hybridValues);
double logProbability = 0;
logProbability +=
hybridBayesNet->at(0)->asMixture()->logProbability(hybridValues);
logProbability +=
hybridBayesNet->at(1)->asMixture()->logProbability(hybridValues);
logProbability +=
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
// TODO(dellaert): the discrete errors are not added in error tree!
EXPECT_DOUBLES_EQUAL(error, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(error, pruned_error_tree(discrete_values), 1e-9);
error += hybridBayesNet->at(3)->asDiscrete()->error(discrete_values);
error += hybridBayesNet->at(4)->asDiscrete()->error(discrete_values);
EXPECT_DOUBLES_EQUAL(error, hybridBayesNet->error(hybridValues), 1e-9);
// TODO(dellaert): the discrete errors are not added in logProbability tree!
EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values),
1e-9);
logProbability +=
hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);
logProbability +=
hybridBayesNet->at(4)->asDiscrete()->logProbability(discrete_values);
EXPECT_DOUBLES_EQUAL(logProbability,
hybridBayesNet->logProbability(hybridValues), 1e-9);
}
/* ****************************************************************************/
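Illustrative sketch, not part of the commit: an arithmetic cross-check of the updated EvaluatePureDiscrete expectation earlier in this file. The "4/6" spec normalizes to P(Asia=0) = 0.4, so evaluate() returns 0.4 and, via evaluate = exp(logProbability), logProbability(Asia=0) is log(0.4), about -0.916.

#include <cassert>
#include <cmath>

int main() {
  // "4/6" normalizes to probabilities 0.4 and 0.6; the test asserts
  // evaluate(Asia=0) == 0.4, which corresponds to logProbability = log(0.4).
  const double logP = std::log(0.4);
  assert(std::abs(std::exp(logP) - 0.4) < 1e-12);
  return 0;
}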