Moved factor graph error(HybridValues) to FactorGraph base class.

release/4.3a0
Frank Dellaert 2023-01-08 18:19:12 -08:00
parent b4706bec85
commit 83bae7d701
10 changed files with 51 additions and 92 deletions

View File

@ -257,29 +257,7 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
/* ************************************************************************* */ /* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const { double HybridBayesNet::evaluate(const HybridValues &values) const {
const DiscreteValues &discreteValues = values.discrete(); return exp(-error(values));
const VectorValues &continuousValues = values.continuous();
double error = 0.0, probability = 1.0;
// Iterate over each conditional.
for (auto &&conditional : *this) {
// TODO: should be delegated to derived classes.
if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues);
error += component->error(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply.
error += gc->error(continuousValues);
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues);
}
}
return probability * exp(-error);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -317,12 +295,6 @@ HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator); return sample(&kRandomNumberGenerator);
} }
/* ************************************************************************* */
double HybridBayesNet::error(const HybridValues &values) const {
GaussianBayesNet gbn = choose(values.discrete());
return gbn.error(values.continuous());
}
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error( AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
@ -332,19 +304,15 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) { if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute error. // If conditional is hybrid, select based on assignment and compute error.
AlgebraicDecisionTree<Key> conditional_error = error_tree = error_tree + gm->error(continuousValues);
gm->error(continuousValues);
error_tree = error_tree + conditional_error;
} else if (auto gc = conditional->asGaussian()) { } else if (auto gc = conditional->asGaussian()) {
// If continuous only, get the (double) error // If continuous, get the (double) error and add it to the error_tree
// and add it to the error_tree
double error = gc->error(continuousValues); double error = gc->error(continuousValues);
// Add the computed error to every leaf of the error tree. // Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply( error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; }); [error](double leaf_value) { return leaf_value + error; });
} else if (auto dc = conditional->asDiscrete()) { } else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, we skip. // TODO(dellaert): if discrete, we need to add error in the right branch?
continue; continue;
} }
} }

View File

@ -187,14 +187,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves); HybridBayesNet prune(size_t maxNrLeaves);
/** using Base::error; // Expose error(const HybridValues&) method..
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const;
/** /**
* @brief Compute conditional error for each discrete assignment, * @brief Compute conditional error for each discrete assignment,

View File

@ -463,24 +463,6 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
return error_tree; return error_tree;
} }
/* ************************************************************************ */
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0;
for (auto &f : factors_) {
if (auto hf = dynamic_pointer_cast<GaussianFactor>(f)) {
error += hf->error(values.continuous());
} else if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
// TODO(dellaert): needs to change when we discard other wrappers.
error += hf->error(values);
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
error -= log((*dtf)(values.discrete()));
} else {
throwRuntimeError("HybridGaussianFactorGraph::error(HV)", f);
}
}
return error;
}
/* ************************************************************************ */ /* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
double error = this->error(values); double error = this->error(values);

View File

@ -145,6 +145,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
using Base::error; // Expose error(const HybridValues&) method..
/** /**
* @brief Compute error for each discrete assignment, * @brief Compute error for each discrete assignment,
* and return as a tree. * and return as a tree.
@ -156,14 +158,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const; AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
/**
* @brief Compute error given a continuous vector values
* and a discrete assignment.
*
* @return double
*/
double error(const HybridValues& values) const;
/** /**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree. * for each discrete assignment, and return as a tree.

View File

@ -55,12 +55,18 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
: Base(graph) {} : Base(graph) {}
/// @} /// @}
/// @name Constructors
/// @{
/// Print the factor graph. /// Print the factor graph.
void print( void print(
const std::string& s = "HybridNonlinearFactorGraph", const std::string& s = "HybridNonlinearFactorGraph",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
/** /**
* @brief Linearize all the continuous factors in the * @brief Linearize all the continuous factors in the
* HybridNonlinearFactorGraph. * HybridNonlinearFactorGraph.
@ -70,6 +76,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
*/ */
HybridGaussianFactorGraph::shared_ptr linearize( HybridGaussianFactorGraph::shared_ptr linearize(
const Values& continuousValues) const; const Values& continuousValues) const;
/// @}
}; };
template <> template <>

View File

@ -212,6 +212,7 @@ TEST(HybridBayesNet, Error) {
HybridBayesNet::shared_ptr hybridBayesNet = HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(); s.linearizedFactorGraph.eliminateSequential();
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
HybridValues delta = hybridBayesNet->optimize(); HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->error(delta.continuous()); auto error_tree = hybridBayesNet->error(delta.continuous());
@ -235,26 +236,21 @@ TEST(HybridBayesNet, Error) {
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6)); EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
// Verify error computation and check for specific error value // Verify error computation and check for specific error value
DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const HybridValues hybridValues{delta.continuous(), discrete_values};
double error = 0;
error += hybridBayesNet->at(0)->asMixture()->error(hybridValues);
error += hybridBayesNet->at(1)->asMixture()->error(hybridValues);
error += hybridBayesNet->at(2)->asMixture()->error(hybridValues);
double total_error = 0; // TODO(dellaert): the discrete errors are not added in error tree!
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { EXPECT_DOUBLES_EQUAL(error, error_tree(discrete_values), 1e-9);
if (hybridBayesNet->at(idx)->isHybrid()) { EXPECT_DOUBLES_EQUAL(error, pruned_error_tree(discrete_values), 1e-9);
double error = hybridBayesNet->at(idx)->asMixture()->error(
{delta.continuous(), discrete_values}); error += hybridBayesNet->at(3)->asDiscrete()->error(discrete_values);
total_error += error; error += hybridBayesNet->at(4)->asDiscrete()->error(discrete_values);
} else if (hybridBayesNet->at(idx)->isContinuous()) { EXPECT_DOUBLES_EQUAL(error, hybridBayesNet->error(hybridValues), 1e-9);
double error =
hybridBayesNet->at(idx)->asGaussian()->error(delta.continuous());
total_error += error;
}
}
EXPECT_DOUBLES_EQUAL(
total_error, hybridBayesNet->error({delta.continuous(), discrete_values}),
1e-9);
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
} }
/* ****************************************************************************/ /* ****************************************************************************/

View File

@ -60,12 +60,14 @@ TEST(HybridFactorGraph, GaussianFactorGraph) {
Values linearizationPoint; Values linearizationPoint;
linearizationPoint.insert<double>(X(0), 0); linearizationPoint.insert<double>(X(0), 0);
// Linearize the factor graph.
HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint); HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint);
EXPECT_LONGS_EQUAL(1, ghfg.size());
// Add a factor to the GaussianFactorGraph // Check that the error is the same for the nonlinear values.
ghfg.add(JacobianFactor(X(0), I_1x1, Vector1(5))); const VectorValues zero{{X(0), Vector1(0)}};
const HybridValues hybridValues{zero, {}, linearizationPoint};
EXPECT_LONGS_EQUAL(2, ghfg.size()); EXPECT_DOUBLES_EQUAL(fg.error(hybridValues), ghfg.error(hybridValues), 1e-9);
} }
/*************************************************************************** /***************************************************************************

View File

@ -61,6 +61,16 @@ bool FactorGraph<FACTOR>::equals(const This& fg, double tol) const {
return true; return true;
} }
/* ************************************************************************ */
template <class FACTOR>
double FactorGraph<FACTOR>::error(const HybridValues &values) const {
double error = 0.0;
for (auto &f : factors_) {
error += f->error(values);
}
return error;
}
/* ************************************************************************* */ /* ************************************************************************* */
template <class FACTOR> template <class FACTOR>
size_t FactorGraph<FACTOR>::nrFactors() const { size_t FactorGraph<FACTOR>::nrFactors() const {

View File

@ -47,6 +47,8 @@ typedef FastVector<FactorIndex> FactorIndices;
template <class CLIQUE> template <class CLIQUE>
class BayesTree; class BayesTree;
class HybridValues;
/** Helper */ /** Helper */
template <class C> template <class C>
class CRefCallPushBack { class CRefCallPushBack {
@ -359,6 +361,9 @@ class FactorGraph {
/** Get the last factor */ /** Get the last factor */
sharedFactor back() const { return factors_.back(); } sharedFactor back() const { return factors_.back(); }
/** Add error for all factors. */
double error(const HybridValues &values) const;
/// @} /// @}
/// @name Modifying Factor Graphs (imperative, discouraged) /// @name Modifying Factor Graphs (imperative, discouraged)
/// @{ /// @{

View File

@ -145,6 +145,8 @@ namespace gtsam {
return exp(logNormalizationConstant()); return exp(logNormalizationConstant());
} }
using Base::error; // Expose error(const HybridValues&) method..
/** /**
* Calculate error(x) == -log(evaluate()) for given values `x`: * Calculate error(x) == -log(evaluate()) for given values `x`:
* - GaussianFactor::error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) * - GaussianFactor::error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)