Moved factor graph error(HybridValues) to FactorGraph base class.

release/4.3a0
Frank Dellaert 2023-01-08 18:19:12 -08:00
parent b4706bec85
commit 83bae7d701
10 changed files with 51 additions and 92 deletions

View File

@@ -257,29 +257,7 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
const DiscreteValues &discreteValues = values.discrete();
const VectorValues &continuousValues = values.continuous();
double error = 0.0, probability = 1.0;
// Iterate over each conditional.
for (auto &&conditional : *this) {
// TODO: should be delegated to derived classes.
if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues);
error += component->error(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply.
error += gc->error(continuousValues);
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues);
}
}
return probability * exp(-error);
return exp(-error(values));
}
/* ************************************************************************* */
@@ -317,12 +295,6 @@ HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}
/* ************************************************************************* */
double HybridBayesNet::error(const HybridValues &values) const {
GaussianBayesNet gbn = choose(values.discrete());
return gbn.error(values.continuous());
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const {
@@ -332,19 +304,15 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute error.
AlgebraicDecisionTree<Key> conditional_error =
gm->error(continuousValues);
error_tree = error_tree + conditional_error;
error_tree = error_tree + gm->error(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, get the (double) error
// and add it to the error_tree
// If continuous, get the (double) error and add it to the error_tree
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, we skip.
// TODO(dellaert): if discrete, we need to add error in the right branch?
continue;
}
}

View File

@@ -187,14 +187,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves);
/**
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const;
using Base::error; // Expose error(const HybridValues&) method..
/**
* @brief Compute conditional error for each discrete assignment,

View File

@@ -463,24 +463,6 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
return error_tree;
}
/* ************************************************************************ */
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0;
for (auto &f : factors_) {
if (auto hf = dynamic_pointer_cast<GaussianFactor>(f)) {
error += hf->error(values.continuous());
} else if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
// TODO(dellaert): needs to change when we discard other wrappers.
error += hf->error(values);
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
error -= log((*dtf)(values.discrete()));
} else {
throwRuntimeError("HybridGaussianFactorGraph::error(HV)", f);
}
}
return error;
}
/* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
double error = this->error(values);

View File

@@ -145,6 +145,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @name Standard Interface
/// @{
using Base::error; // Expose error(const HybridValues&) method..
/**
* @brief Compute error for each discrete assignment,
* and return as a tree.
@@ -156,14 +158,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
/**
* @brief Compute error given a continuous vector values
* and a discrete assignment.
*
* @return double
*/
double error(const HybridValues& values) const;
/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.

View File

@@ -55,12 +55,18 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
: Base(graph) {}
/// @}
/// @name Constructors
/// @{
/// Print the factor graph.
void print(
const std::string& s = "HybridNonlinearFactorGraph",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
/**
* @brief Linearize all the continuous factors in the
* HybridNonlinearFactorGraph.
@@ -70,6 +76,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
*/
HybridGaussianFactorGraph::shared_ptr linearize(
const Values& continuousValues) const;
/// @}
};
template <>

View File

@@ -212,6 +212,7 @@ TEST(HybridBayesNet, Error) {
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential();
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->error(delta.continuous());
@@ -235,26 +236,21 @@ TEST(HybridBayesNet, Error) {
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
// Verify error computation and check for specific error value
DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const HybridValues hybridValues{delta.continuous(), discrete_values};
double error = 0;
error += hybridBayesNet->at(0)->asMixture()->error(hybridValues);
error += hybridBayesNet->at(1)->asMixture()->error(hybridValues);
error += hybridBayesNet->at(2)->asMixture()->error(hybridValues);
double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
if (hybridBayesNet->at(idx)->isHybrid()) {
double error = hybridBayesNet->at(idx)->asMixture()->error(
{delta.continuous(), discrete_values});
total_error += error;
} else if (hybridBayesNet->at(idx)->isContinuous()) {
double error =
hybridBayesNet->at(idx)->asGaussian()->error(delta.continuous());
total_error += error;
}
}
// TODO(dellaert): the discrete errors are not added in error tree!
EXPECT_DOUBLES_EQUAL(error, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(error, pruned_error_tree(discrete_values), 1e-9);
error += hybridBayesNet->at(3)->asDiscrete()->error(discrete_values);
error += hybridBayesNet->at(4)->asDiscrete()->error(discrete_values);
EXPECT_DOUBLES_EQUAL(error, hybridBayesNet->error(hybridValues), 1e-9);
EXPECT_DOUBLES_EQUAL(
total_error, hybridBayesNet->error({delta.continuous(), discrete_values}),
1e-9);
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
}
/* ****************************************************************************/

View File

@@ -60,12 +60,14 @@ TEST(HybridFactorGraph, GaussianFactorGraph) {
Values linearizationPoint;
linearizationPoint.insert<double>(X(0), 0);
// Linearize the factor graph.
HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint);
EXPECT_LONGS_EQUAL(1, ghfg.size());
// Add a factor to the GaussianFactorGraph
ghfg.add(JacobianFactor(X(0), I_1x1, Vector1(5)));
EXPECT_LONGS_EQUAL(2, ghfg.size());
// Check that the error is the same for the nonlinear values.
const VectorValues zero{{X(0), Vector1(0)}};
const HybridValues hybridValues{zero, {}, linearizationPoint};
EXPECT_DOUBLES_EQUAL(fg.error(hybridValues), ghfg.error(hybridValues), 1e-9);
}
/***************************************************************************

View File

@@ -61,6 +61,16 @@ bool FactorGraph<FACTOR>::equals(const This& fg, double tol) const {
return true;
}
/* ************************************************************************ */
template <class FACTOR>
double FactorGraph<FACTOR>::error(const HybridValues &values) const {
  // Total error is the sum of each factor's error at the given hybrid values.
  double total = 0.0;
  for (const auto &factor : factors_) {
    total += factor->error(values);
  }
  return total;
}
/* ************************************************************************* */
template <class FACTOR>
size_t FactorGraph<FACTOR>::nrFactors() const {

View File

@@ -47,6 +47,8 @@ typedef FastVector<FactorIndex> FactorIndices;
template <class CLIQUE>
class BayesTree;
class HybridValues;
/** Helper */
template <class C>
class CRefCallPushBack {
@@ -359,6 +361,9 @@ class FactorGraph {
/** Get the last factor */
sharedFactor back() const { return factors_.back(); }
/** Add error for all factors. */
double error(const HybridValues &values) const;
/// @}
/// @name Modifying Factor Graphs (imperative, discouraged)
/// @{

View File

@@ -145,6 +145,8 @@ namespace gtsam {
return exp(logNormalizationConstant());
}
using Base::error; // Expose error(const HybridValues&) method..
/**
* Calculate error(x) == -log(evaluate()) for given values `x`:
* - GaussianFactor::error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)