Endowed hybrid with logProbability
parent
11ef99b3f0
commit
426a49dc72
|
@ -271,15 +271,16 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||
AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to calculate to double error value from GaussianConditional.
|
||||
// functor to calculate to double logProbability value from
|
||||
// GaussianConditional.
|
||||
auto errorFunc =
|
||||
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
|
||||
if (conditional) {
|
||||
return conditional->error(continuousValues);
|
||||
return conditional->logProbability(continuousValues);
|
||||
} else {
|
||||
// Return arbitrarily large error if conditional is null
|
||||
// Return arbitrarily large logProbability if conditional is null
|
||||
// Conditional is null if it is pruned out.
|
||||
return 1e50;
|
||||
}
|
||||
|
@ -289,10 +290,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixture::error(const HybridValues &values) const {
|
||||
double GaussianMixture::logProbability(const HybridValues &values) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto conditional = conditionals_(values.discrete());
|
||||
return conditional->error(values.continuous());
|
||||
return conditional->logProbability(values.continuous());
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -164,22 +164,23 @@ class GTSAM_EXPORT GaussianMixture
|
|||
const Conditionals &conditionals() const;
|
||||
|
||||
/**
|
||||
* @brief Compute error of the GaussianMixture as a tree.
|
||||
* @brief Compute logProbability of the GaussianMixture as a tree.
|
||||
*
|
||||
* @param continuousValues The continuous VectorValues.
|
||||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||
* as the conditionals, and leaf values as the error.
|
||||
* as the conditionals, and leaf values as the logProbability.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
AlgebraicDecisionTree<Key> logProbability(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||
* values and a discrete assignment.
|
||||
* @brief Compute the logProbability of this Gaussian Mixture given the
|
||||
* continuous values and a discrete assignment.
|
||||
*
|
||||
* @param values Continuous values and discrete assignment.
|
||||
* @return double
|
||||
*/
|
||||
double error(const HybridValues &values) const override;
|
||||
double logProbability(const HybridValues &values) const override;
|
||||
|
||||
// /// Calculate probability density for given values `x`.
|
||||
// double evaluate(const HybridValues &values) const;
|
||||
|
|
|
@ -255,11 +255,6 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
|||
return gbn.optimize();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double HybridBayesNet::evaluate(const HybridValues &values) const {
|
||||
return exp(-error(values));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesNet::sample(const HybridValues &given,
|
||||
std::mt19937_64 *rng) const {
|
||||
|
@ -296,23 +291,28 @@ HybridValues HybridBayesNet::sample() const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// If conditional is hybrid, select based on assignment and compute error.
|
||||
error_tree = error_tree + gm->error(continuousValues);
|
||||
// If conditional is hybrid, select based on assignment and compute
|
||||
// logProbability.
|
||||
error_tree = error_tree + gm->logProbability(continuousValues);
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous, get the (double) error and add it to the error_tree
|
||||
double error = gc->error(continuousValues);
|
||||
// Add the computed error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
// If continuous, get the (double) logProbability and add it to the
|
||||
// error_tree
|
||||
double logProbability = gc->logProbability(continuousValues);
|
||||
// Add the computed logProbability to every leaf of the logProbability
|
||||
// tree.
|
||||
error_tree = error_tree.apply([logProbability](double leaf_value) {
|
||||
return leaf_value + logProbability;
|
||||
});
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// TODO(dellaert): if discrete, we need to add error in the right branch?
|
||||
// TODO(dellaert): if discrete, we need to add logProbability in the right
|
||||
// branch?
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -321,10 +321,15 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
||||
return error_tree.apply([](double error) { return exp(-error); });
|
||||
AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
|
||||
return tree.apply([](double log) { return exp(log); });
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double HybridBayesNet::evaluate(const HybridValues &values) const {
|
||||
return exp(logProbability(values));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -187,8 +187,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
HybridBayesNet prune(size_t maxNrLeaves);
|
||||
|
||||
using Base::error; // Expose error(const HybridValues&) method..
|
||||
|
||||
/**
|
||||
* @brief Compute conditional error for each discrete assignment,
|
||||
* and return as a tree.
|
||||
|
@ -196,7 +194,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
AlgebraicDecisionTree<Key> logProbability(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
using BayesNet::logProbability; // expose HybridValues version
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability q(μ|M),
|
||||
|
@ -208,7 +209,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* probability.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> probPrime(
|
||||
AlgebraicDecisionTree<Key> evaluate(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
|
|
|
@ -122,18 +122,18 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridConditional::error(const HybridValues &values) const {
|
||||
if (auto gm = asMixture()) {
|
||||
return gm->error(values);
|
||||
}
|
||||
double HybridConditional::logProbability(const HybridValues &values) const {
|
||||
if (auto gc = asGaussian()) {
|
||||
return gc->error(values.continuous());
|
||||
return gc->logProbability(values.continuous());
|
||||
}
|
||||
if (auto gm = asMixture()) {
|
||||
return gm->logProbability(values);
|
||||
}
|
||||
if (auto dc = asDiscrete()) {
|
||||
return -log((*dc)(values.discrete()));
|
||||
return dc->logProbability(values.discrete());
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: conditional type not handled");
|
||||
"HybridConditional::logProbability: conditional type not handled");
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -176,8 +176,8 @@ class GTSAM_EXPORT HybridConditional
|
|||
/// Get the type-erased pointer to the inner type
|
||||
boost::shared_ptr<Factor> inner() const { return inner_; }
|
||||
|
||||
/// Return the error of the underlying conditional.
|
||||
double error(const HybridValues& values) const override;
|
||||
/// Return the logProbability of the underlying conditional.
|
||||
double logProbability(const HybridValues& values) const override;
|
||||
|
||||
/// Check if VectorValues `measurements` contains all frontal keys.
|
||||
bool frontalsIn(const VectorValues& measurements) const {
|
||||
|
|
|
@ -116,12 +116,12 @@ TEST(GaussianMixture, Error) {
|
|||
VectorValues values;
|
||||
values.insert(X(1), Vector2::Ones());
|
||||
values.insert(X(2), Vector2::Zero());
|
||||
auto error_tree = mixture.error(values);
|
||||
auto error_tree = mixture.logProbability(values);
|
||||
|
||||
// Check result.
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
std::vector<double> leaves = {conditional0->error(values),
|
||||
conditional1->error(values)};
|
||||
std::vector<double> leaves = {conditional0->logProbability(values),
|
||||
conditional1->logProbability(values)};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
||||
|
@ -129,11 +129,11 @@ TEST(GaussianMixture, Error) {
|
|||
// Regression for non-tree version.
|
||||
DiscreteValues assignment;
|
||||
assignment[M(1)] = 0;
|
||||
EXPECT_DOUBLES_EQUAL(conditional0->error(values),
|
||||
mixture.error({values, assignment}), 1e-8);
|
||||
EXPECT_DOUBLES_EQUAL(conditional0->logProbability(values),
|
||||
mixture.logProbability({values, assignment}), 1e-8);
|
||||
assignment[M(1)] = 1;
|
||||
EXPECT_DOUBLES_EQUAL(conditional1->error(values),
|
||||
mixture.error({values, assignment}), 1e-8);
|
||||
EXPECT_DOUBLES_EQUAL(conditional1->logProbability(values),
|
||||
mixture.logProbability({values, assignment}), 1e-8);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -64,10 +64,10 @@ TEST(HybridBayesNet, Add) {
|
|||
// Test evaluate for a pure discrete Bayes net P(Asia).
|
||||
TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||
HybridBayesNet bayesNet;
|
||||
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
|
||||
bayesNet.emplace_back(new DiscreteConditional(Asia, "4/6"));
|
||||
HybridValues values;
|
||||
values.insert(asiaKey, 0);
|
||||
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
@ -207,7 +207,7 @@ TEST(HybridBayesNet, Optimize) {
|
|||
|
||||
/* ****************************************************************************/
|
||||
// Test Bayes net error
|
||||
TEST(HybridBayesNet, Error) {
|
||||
TEST(HybridBayesNet, logProbability) {
|
||||
Switching s(3);
|
||||
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
|
@ -215,42 +215,49 @@ TEST(HybridBayesNet, Error) {
|
|||
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = hybridBayesNet->error(delta.continuous());
|
||||
auto error_tree = hybridBayesNet->logProbability(delta.continuous());
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
std::vector<double> leaves = {-4.1609374, -4.1706942, -4.141568, -4.1609374};
|
||||
std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
||||
|
||||
// Error on pruned Bayes net
|
||||
// logProbability on pruned Bayes net
|
||||
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||
auto pruned_error_tree = prunedBayesNet.error(delta.continuous());
|
||||
auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous());
|
||||
|
||||
std::vector<double> pruned_leaves = {2e50, -4.1706942, 2e50, -4.1609374};
|
||||
std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
|
||||
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
|
||||
pruned_leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
|
||||
|
||||
// Verify error computation and check for specific error value
|
||||
// Verify logProbability computation and check for specific logProbability
|
||||
// value
|
||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||
double error = 0;
|
||||
error += hybridBayesNet->at(0)->asMixture()->error(hybridValues);
|
||||
error += hybridBayesNet->at(1)->asMixture()->error(hybridValues);
|
||||
error += hybridBayesNet->at(2)->asMixture()->error(hybridValues);
|
||||
double logProbability = 0;
|
||||
logProbability +=
|
||||
hybridBayesNet->at(0)->asMixture()->logProbability(hybridValues);
|
||||
logProbability +=
|
||||
hybridBayesNet->at(1)->asMixture()->logProbability(hybridValues);
|
||||
logProbability +=
|
||||
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
|
||||
|
||||
// TODO(dellaert): the discrete errors are not added in error tree!
|
||||
EXPECT_DOUBLES_EQUAL(error, error_tree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error, pruned_error_tree(discrete_values), 1e-9);
|
||||
|
||||
error += hybridBayesNet->at(3)->asDiscrete()->error(discrete_values);
|
||||
error += hybridBayesNet->at(4)->asDiscrete()->error(discrete_values);
|
||||
EXPECT_DOUBLES_EQUAL(error, hybridBayesNet->error(hybridValues), 1e-9);
|
||||
// TODO(dellaert): the discrete errors are not added in logProbability tree!
|
||||
EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values),
|
||||
1e-9);
|
||||
|
||||
logProbability +=
|
||||
hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);
|
||||
logProbability +=
|
||||
hybridBayesNet->at(4)->asDiscrete()->logProbability(discrete_values);
|
||||
EXPECT_DOUBLES_EQUAL(logProbability,
|
||||
hybridBayesNet->logProbability(hybridValues), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
|
Loading…
Reference in New Issue