Endowed hybrid with logProbability
parent
11ef99b3f0
commit
426a49dc72
|
@ -271,15 +271,16 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
AlgebraicDecisionTree<Key> GaussianMixture::error(
|
AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
// functor to calculate to double error value from GaussianConditional.
|
// functor to calculate to double logProbability value from
|
||||||
|
// GaussianConditional.
|
||||||
auto errorFunc =
|
auto errorFunc =
|
||||||
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
|
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
|
||||||
if (conditional) {
|
if (conditional) {
|
||||||
return conditional->error(continuousValues);
|
return conditional->logProbability(continuousValues);
|
||||||
} else {
|
} else {
|
||||||
// Return arbitrarily large error if conditional is null
|
// Return arbitrarily large logProbability if conditional is null
|
||||||
// Conditional is null if it is pruned out.
|
// Conditional is null if it is pruned out.
|
||||||
return 1e50;
|
return 1e50;
|
||||||
}
|
}
|
||||||
|
@ -289,10 +290,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixture::error(const HybridValues &values) const {
|
double GaussianMixture::logProbability(const HybridValues &values) const {
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
// Directly index to get the conditional, no need to build the whole tree.
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto conditional = conditionals_(values.discrete());
|
||||||
return conditional->error(values.continuous());
|
return conditional->logProbability(values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -164,22 +164,23 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
const Conditionals &conditionals() const;
|
const Conditionals &conditionals() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute error of the GaussianMixture as a tree.
|
* @brief Compute logProbability of the GaussianMixture as a tree.
|
||||||
*
|
*
|
||||||
* @param continuousValues The continuous VectorValues.
|
* @param continuousValues The continuous VectorValues.
|
||||||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||||
* as the conditionals, and leaf values as the error.
|
* as the conditionals, and leaf values as the logProbability.
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
AlgebraicDecisionTree<Key> logProbability(
|
||||||
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
* @brief Compute the logProbability of this Gaussian Mixture given the
|
||||||
* values and a discrete assignment.
|
* continuous values and a discrete assignment.
|
||||||
*
|
*
|
||||||
* @param values Continuous values and discrete assignment.
|
* @param values Continuous values and discrete assignment.
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues &values) const override;
|
double logProbability(const HybridValues &values) const override;
|
||||||
|
|
||||||
// /// Calculate probability density for given values `x`.
|
// /// Calculate probability density for given values `x`.
|
||||||
// double evaluate(const HybridValues &values) const;
|
// double evaluate(const HybridValues &values) const;
|
||||||
|
|
|
@ -255,11 +255,6 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
||||||
return gbn.optimize();
|
return gbn.optimize();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
double HybridBayesNet::evaluate(const HybridValues &values) const {
|
|
||||||
return exp(-error(values));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::sample(const HybridValues &given,
|
HybridValues HybridBayesNet::sample(const HybridValues &given,
|
||||||
std::mt19937_64 *rng) const {
|
std::mt19937_64 *rng) const {
|
||||||
|
@ -296,23 +291,28 @@ HybridValues HybridBayesNet::sample() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||||
|
|
||||||
// Iterate over each conditional.
|
// Iterate over each conditional.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asMixture()) {
|
||||||
// If conditional is hybrid, select based on assignment and compute error.
|
// If conditional is hybrid, select based on assignment and compute
|
||||||
error_tree = error_tree + gm->error(continuousValues);
|
// logProbability.
|
||||||
|
error_tree = error_tree + gm->logProbability(continuousValues);
|
||||||
} else if (auto gc = conditional->asGaussian()) {
|
} else if (auto gc = conditional->asGaussian()) {
|
||||||
// If continuous, get the (double) error and add it to the error_tree
|
// If continuous, get the (double) logProbability and add it to the
|
||||||
double error = gc->error(continuousValues);
|
// error_tree
|
||||||
// Add the computed error to every leaf of the error tree.
|
double logProbability = gc->logProbability(continuousValues);
|
||||||
error_tree = error_tree.apply(
|
// Add the computed logProbability to every leaf of the logProbability
|
||||||
[error](double leaf_value) { return leaf_value + error; });
|
// tree.
|
||||||
|
error_tree = error_tree.apply([logProbability](double leaf_value) {
|
||||||
|
return leaf_value + logProbability;
|
||||||
|
});
|
||||||
} else if (auto dc = conditional->asDiscrete()) {
|
} else if (auto dc = conditional->asDiscrete()) {
|
||||||
// TODO(dellaert): if discrete, we need to add error in the right branch?
|
// TODO(dellaert): if discrete, we need to add logProbability in the right
|
||||||
|
// branch?
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -321,10 +321,15 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
|
AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
|
||||||
return error_tree.apply([](double error) { return exp(-error); });
|
return tree.apply([](double log) { return exp(log); });
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double HybridBayesNet::evaluate(const HybridValues &values) const {
|
||||||
|
return exp(logProbability(values));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -187,8 +187,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||||
HybridBayesNet prune(size_t maxNrLeaves);
|
HybridBayesNet prune(size_t maxNrLeaves);
|
||||||
|
|
||||||
using Base::error; // Expose error(const HybridValues&) method..
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute conditional error for each discrete assignment,
|
* @brief Compute conditional error for each discrete assignment,
|
||||||
* and return as a tree.
|
* and return as a tree.
|
||||||
|
@ -196,7 +194,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @param continuousValues Continuous values at which to compute the error.
|
* @param continuousValues Continuous values at which to compute the error.
|
||||||
* @return AlgebraicDecisionTree<Key>
|
* @return AlgebraicDecisionTree<Key>
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
AlgebraicDecisionTree<Key> logProbability(
|
||||||
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
|
using BayesNet::logProbability; // expose HybridValues version
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute unnormalized probability q(μ|M),
|
* @brief Compute unnormalized probability q(μ|M),
|
||||||
|
@ -208,7 +209,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* probability.
|
* probability.
|
||||||
* @return AlgebraicDecisionTree<Key>
|
* @return AlgebraicDecisionTree<Key>
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> probPrime(
|
AlgebraicDecisionTree<Key> evaluate(
|
||||||
const VectorValues &continuousValues) const;
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -122,18 +122,18 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double HybridConditional::error(const HybridValues &values) const {
|
double HybridConditional::logProbability(const HybridValues &values) const {
|
||||||
if (auto gm = asMixture()) {
|
|
||||||
return gm->error(values);
|
|
||||||
}
|
|
||||||
if (auto gc = asGaussian()) {
|
if (auto gc = asGaussian()) {
|
||||||
return gc->error(values.continuous());
|
return gc->logProbability(values.continuous());
|
||||||
|
}
|
||||||
|
if (auto gm = asMixture()) {
|
||||||
|
return gm->logProbability(values);
|
||||||
}
|
}
|
||||||
if (auto dc = asDiscrete()) {
|
if (auto dc = asDiscrete()) {
|
||||||
return -log((*dc)(values.discrete()));
|
return dc->logProbability(values.discrete());
|
||||||
}
|
}
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"HybridConditional::error: conditional type not handled");
|
"HybridConditional::logProbability: conditional type not handled");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -176,8 +176,8 @@ class GTSAM_EXPORT HybridConditional
|
||||||
/// Get the type-erased pointer to the inner type
|
/// Get the type-erased pointer to the inner type
|
||||||
boost::shared_ptr<Factor> inner() const { return inner_; }
|
boost::shared_ptr<Factor> inner() const { return inner_; }
|
||||||
|
|
||||||
/// Return the error of the underlying conditional.
|
/// Return the logProbability of the underlying conditional.
|
||||||
double error(const HybridValues& values) const override;
|
double logProbability(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// Check if VectorValues `measurements` contains all frontal keys.
|
/// Check if VectorValues `measurements` contains all frontal keys.
|
||||||
bool frontalsIn(const VectorValues& measurements) const {
|
bool frontalsIn(const VectorValues& measurements) const {
|
||||||
|
|
|
@ -116,12 +116,12 @@ TEST(GaussianMixture, Error) {
|
||||||
VectorValues values;
|
VectorValues values;
|
||||||
values.insert(X(1), Vector2::Ones());
|
values.insert(X(1), Vector2::Ones());
|
||||||
values.insert(X(2), Vector2::Zero());
|
values.insert(X(2), Vector2::Zero());
|
||||||
auto error_tree = mixture.error(values);
|
auto error_tree = mixture.logProbability(values);
|
||||||
|
|
||||||
// Check result.
|
// Check result.
|
||||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||||
std::vector<double> leaves = {conditional0->error(values),
|
std::vector<double> leaves = {conditional0->logProbability(values),
|
||||||
conditional1->error(values)};
|
conditional1->logProbability(values)};
|
||||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||||
|
|
||||||
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
||||||
|
@ -129,11 +129,11 @@ TEST(GaussianMixture, Error) {
|
||||||
// Regression for non-tree version.
|
// Regression for non-tree version.
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
assignment[M(1)] = 0;
|
assignment[M(1)] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL(conditional0->error(values),
|
EXPECT_DOUBLES_EQUAL(conditional0->logProbability(values),
|
||||||
mixture.error({values, assignment}), 1e-8);
|
mixture.logProbability({values, assignment}), 1e-8);
|
||||||
assignment[M(1)] = 1;
|
assignment[M(1)] = 1;
|
||||||
EXPECT_DOUBLES_EQUAL(conditional1->error(values),
|
EXPECT_DOUBLES_EQUAL(conditional1->logProbability(values),
|
||||||
mixture.error({values, assignment}), 1e-8);
|
mixture.logProbability({values, assignment}), 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -64,10 +64,10 @@ TEST(HybridBayesNet, Add) {
|
||||||
// Test evaluate for a pure discrete Bayes net P(Asia).
|
// Test evaluate for a pure discrete Bayes net P(Asia).
|
||||||
TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||||
HybridBayesNet bayesNet;
|
HybridBayesNet bayesNet;
|
||||||
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
|
bayesNet.emplace_back(new DiscreteConditional(Asia, "4/6"));
|
||||||
HybridValues values;
|
HybridValues values;
|
||||||
values.insert(asiaKey, 0);
|
values.insert(asiaKey, 0);
|
||||||
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
@ -207,7 +207,7 @@ TEST(HybridBayesNet, Optimize) {
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test Bayes net error
|
// Test Bayes net error
|
||||||
TEST(HybridBayesNet, Error) {
|
TEST(HybridBayesNet, logProbability) {
|
||||||
Switching s(3);
|
Switching s(3);
|
||||||
|
|
||||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||||
|
@ -215,42 +215,49 @@ TEST(HybridBayesNet, Error) {
|
||||||
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
|
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
|
||||||
|
|
||||||
HybridValues delta = hybridBayesNet->optimize();
|
HybridValues delta = hybridBayesNet->optimize();
|
||||||
auto error_tree = hybridBayesNet->error(delta.continuous());
|
auto error_tree = hybridBayesNet->logProbability(delta.continuous());
|
||||||
|
|
||||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||||
std::vector<double> leaves = {-4.1609374, -4.1706942, -4.141568, -4.1609374};
|
std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
|
||||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||||
|
|
||||||
// regression
|
// regression
|
||||||
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
||||||
|
|
||||||
// Error on pruned Bayes net
|
// logProbability on pruned Bayes net
|
||||||
auto prunedBayesNet = hybridBayesNet->prune(2);
|
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||||
auto pruned_error_tree = prunedBayesNet.error(delta.continuous());
|
auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous());
|
||||||
|
|
||||||
std::vector<double> pruned_leaves = {2e50, -4.1706942, 2e50, -4.1609374};
|
std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
|
||||||
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
|
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
|
||||||
pruned_leaves);
|
pruned_leaves);
|
||||||
|
|
||||||
// regression
|
// regression
|
||||||
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
|
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
|
||||||
|
|
||||||
// Verify error computation and check for specific error value
|
// Verify logProbability computation and check for specific logProbability
|
||||||
|
// value
|
||||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||||
double error = 0;
|
double logProbability = 0;
|
||||||
error += hybridBayesNet->at(0)->asMixture()->error(hybridValues);
|
logProbability +=
|
||||||
error += hybridBayesNet->at(1)->asMixture()->error(hybridValues);
|
hybridBayesNet->at(0)->asMixture()->logProbability(hybridValues);
|
||||||
error += hybridBayesNet->at(2)->asMixture()->error(hybridValues);
|
logProbability +=
|
||||||
|
hybridBayesNet->at(1)->asMixture()->logProbability(hybridValues);
|
||||||
|
logProbability +=
|
||||||
|
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
|
||||||
|
|
||||||
// TODO(dellaert): the discrete errors are not added in error tree!
|
// TODO(dellaert): the discrete errors are not added in logProbability tree!
|
||||||
EXPECT_DOUBLES_EQUAL(error, error_tree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(error, pruned_error_tree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values),
|
||||||
|
1e-9);
|
||||||
error += hybridBayesNet->at(3)->asDiscrete()->error(discrete_values);
|
|
||||||
error += hybridBayesNet->at(4)->asDiscrete()->error(discrete_values);
|
|
||||||
EXPECT_DOUBLES_EQUAL(error, hybridBayesNet->error(hybridValues), 1e-9);
|
|
||||||
|
|
||||||
|
logProbability +=
|
||||||
|
hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);
|
||||||
|
logProbability +=
|
||||||
|
hybridBayesNet->at(4)->asDiscrete()->logProbability(discrete_values);
|
||||||
|
EXPECT_DOUBLES_EQUAL(logProbability,
|
||||||
|
hybridBayesNet->logProbability(hybridValues), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
Loading…
Reference in New Issue