Moved factor graph error(HybridValues) to FactorGraph base class.
parent
b4706bec85
commit
83bae7d701
|
|
@ -257,29 +257,7 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
double HybridBayesNet::evaluate(const HybridValues &values) const {
|
||||
const DiscreteValues &discreteValues = values.discrete();
|
||||
const VectorValues &continuousValues = values.continuous();
|
||||
|
||||
double error = 0.0, probability = 1.0;
|
||||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
// TODO: should be delegated to derived classes.
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
const auto component = (*gm)(discreteValues);
|
||||
error += component->error(continuousValues);
|
||||
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous only, evaluate the probability and multiply.
|
||||
error += gc->error(continuousValues);
|
||||
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// Conditional is discrete-only, so return its probability.
|
||||
probability *= dc->operator()(discreteValues);
|
||||
}
|
||||
}
|
||||
|
||||
return probability * exp(-error);
|
||||
return exp(-error(values));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -317,12 +295,6 @@ HybridValues HybridBayesNet::sample() const {
|
|||
return sample(&kRandomNumberGenerator);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double HybridBayesNet::error(const HybridValues &values) const {
|
||||
GaussianBayesNet gbn = choose(values.discrete());
|
||||
return gbn.error(values.continuous());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
|
|
@ -332,19 +304,15 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
|||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// If conditional is hybrid, select based on assignment and compute error.
|
||||
AlgebraicDecisionTree<Key> conditional_error =
|
||||
gm->error(continuousValues);
|
||||
|
||||
error_tree = error_tree + conditional_error;
|
||||
error_tree = error_tree + gm->error(continuousValues);
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
// If continuous, get the (double) error and add it to the error_tree
|
||||
double error = gc->error(continuousValues);
|
||||
// Add the computed error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// Conditional is discrete-only, we skip.
|
||||
// TODO(dellaert): if discrete, we need to add error in the right branch?
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -187,14 +187,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
HybridBayesNet prune(size_t maxNrLeaves);
|
||||
|
||||
/**
|
||||
* @brief 0.5 * sum of squared Mahalanobis distances
|
||||
* for a specific discrete assignment.
|
||||
*
|
||||
* @param values Continuous values and discrete assignment.
|
||||
* @return double
|
||||
*/
|
||||
double error(const HybridValues &values) const;
|
||||
using Base::error; // Expose error(const HybridValues&) method..
|
||||
|
||||
/**
|
||||
* @brief Compute conditional error for each discrete assignment,
|
||||
|
|
|
|||
|
|
@ -463,24 +463,6 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
return error_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||
double error = 0.0;
|
||||
for (auto &f : factors_) {
|
||||
if (auto hf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||
error += hf->error(values.continuous());
|
||||
} else if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
||||
// TODO(dellaert): needs to change when we discard other wrappers.
|
||||
error += hf->error(values);
|
||||
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
error -= log((*dtf)(values.discrete()));
|
||||
} else {
|
||||
throwRuntimeError("HybridGaussianFactorGraph::error(HV)", f);
|
||||
}
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
||||
double error = this->error(values);
|
||||
|
|
|
|||
|
|
@ -145,6 +145,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
using Base::error; // Expose error(const HybridValues&) method..
|
||||
|
||||
/**
|
||||
* @brief Compute error for each discrete assignment,
|
||||
* and return as a tree.
|
||||
|
|
@ -156,14 +158,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute error given a continuous vector values
|
||||
* and a discrete assignment.
|
||||
*
|
||||
* @return double
|
||||
*/
|
||||
double error(const HybridValues& values) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||
* for each discrete assignment, and return as a tree.
|
||||
|
|
|
|||
|
|
@ -55,12 +55,18 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
|||
: Base(graph) {}
|
||||
|
||||
/// @}
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Print the factor graph.
|
||||
void print(
|
||||
const std::string& s = "HybridNonlinearFactorGraph",
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* @brief Linearize all the continuous factors in the
|
||||
* HybridNonlinearFactorGraph.
|
||||
|
|
@ -70,6 +76,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
|||
*/
|
||||
HybridGaussianFactorGraph::shared_ptr linearize(
|
||||
const Values& continuousValues) const;
|
||||
/// @}
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
|
|||
|
|
@ -212,6 +212,7 @@ TEST(HybridBayesNet, Error) {
|
|||
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential();
|
||||
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = hybridBayesNet->error(delta.continuous());
|
||||
|
|
@ -235,26 +236,21 @@ TEST(HybridBayesNet, Error) {
|
|||
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
|
||||
|
||||
// Verify error computation and check for specific error value
|
||||
DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||
double error = 0;
|
||||
error += hybridBayesNet->at(0)->asMixture()->error(hybridValues);
|
||||
error += hybridBayesNet->at(1)->asMixture()->error(hybridValues);
|
||||
error += hybridBayesNet->at(2)->asMixture()->error(hybridValues);
|
||||
|
||||
double total_error = 0;
|
||||
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
|
||||
if (hybridBayesNet->at(idx)->isHybrid()) {
|
||||
double error = hybridBayesNet->at(idx)->asMixture()->error(
|
||||
{delta.continuous(), discrete_values});
|
||||
total_error += error;
|
||||
} else if (hybridBayesNet->at(idx)->isContinuous()) {
|
||||
double error =
|
||||
hybridBayesNet->at(idx)->asGaussian()->error(delta.continuous());
|
||||
total_error += error;
|
||||
}
|
||||
}
|
||||
// TODO(dellaert): the discrete errors are not added in error tree!
|
||||
EXPECT_DOUBLES_EQUAL(error, error_tree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error, pruned_error_tree(discrete_values), 1e-9);
|
||||
|
||||
error += hybridBayesNet->at(3)->asDiscrete()->error(discrete_values);
|
||||
error += hybridBayesNet->at(4)->asDiscrete()->error(discrete_values);
|
||||
EXPECT_DOUBLES_EQUAL(error, hybridBayesNet->error(hybridValues), 1e-9);
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(
|
||||
total_error, hybridBayesNet->error({delta.continuous(), discrete_values}),
|
||||
1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
|
|
|||
|
|
@ -60,12 +60,14 @@ TEST(HybridFactorGraph, GaussianFactorGraph) {
|
|||
Values linearizationPoint;
|
||||
linearizationPoint.insert<double>(X(0), 0);
|
||||
|
||||
// Linearize the factor graph.
|
||||
HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint);
|
||||
EXPECT_LONGS_EQUAL(1, ghfg.size());
|
||||
|
||||
// Add a factor to the GaussianFactorGraph
|
||||
ghfg.add(JacobianFactor(X(0), I_1x1, Vector1(5)));
|
||||
|
||||
EXPECT_LONGS_EQUAL(2, ghfg.size());
|
||||
// Check that the error is the same for the nonlinear values.
|
||||
const VectorValues zero{{X(0), Vector1(0)}};
|
||||
const HybridValues hybridValues{zero, {}, linearizationPoint};
|
||||
EXPECT_DOUBLES_EQUAL(fg.error(hybridValues), ghfg.error(hybridValues), 1e-9);
|
||||
}
|
||||
|
||||
/***************************************************************************
|
||||
|
|
|
|||
|
|
@ -61,6 +61,16 @@ bool FactorGraph<FACTOR>::equals(const This& fg, double tol) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
template <class FACTOR>
|
||||
double FactorGraph<FACTOR>::error(const HybridValues &values) const {
|
||||
double error = 0.0;
|
||||
for (auto &f : factors_) {
|
||||
error += f->error(values);
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FACTOR>
|
||||
size_t FactorGraph<FACTOR>::nrFactors() const {
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ typedef FastVector<FactorIndex> FactorIndices;
|
|||
template <class CLIQUE>
|
||||
class BayesTree;
|
||||
|
||||
class HybridValues;
|
||||
|
||||
/** Helper */
|
||||
template <class C>
|
||||
class CRefCallPushBack {
|
||||
|
|
@ -359,6 +361,9 @@ class FactorGraph {
|
|||
/** Get the last factor */
|
||||
sharedFactor back() const { return factors_.back(); }
|
||||
|
||||
/** Add error for all factors. */
|
||||
double error(const HybridValues &values) const;
|
||||
|
||||
/// @}
|
||||
/// @name Modifying Factor Graphs (imperative, discouraged)
|
||||
/// @{
|
||||
|
|
|
|||
|
|
@ -145,6 +145,8 @@ namespace gtsam {
|
|||
return exp(logNormalizationConstant());
|
||||
}
|
||||
|
||||
using Base::error; // Expose error(const HybridValues&) method..
|
||||
|
||||
/**
|
||||
* Calculate error(x) == -log(evaluate()) for given values `x`:
|
||||
* - GaussianFactor::error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
||||
|
|
|
|||
Loading…
Reference in New Issue