diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 1b12255ac..e662b9c81 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -257,29 +257,7 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { /* ************************************************************************* */ double HybridBayesNet::evaluate(const HybridValues &values) const { - const DiscreteValues &discreteValues = values.discrete(); - const VectorValues &continuousValues = values.continuous(); - - double error = 0.0, probability = 1.0; - - // Iterate over each conditional. - for (auto &&conditional : *this) { - // TODO: should be delegated to derived classes. - if (auto gm = conditional->asMixture()) { - const auto component = (*gm)(discreteValues); - error += component->error(continuousValues); - - } else if (auto gc = conditional->asGaussian()) { - // If continuous only, evaluate the probability and multiply. - error += gc->error(continuousValues); - - } else if (auto dc = conditional->asDiscrete()) { - // Conditional is discrete-only, so return its probability. - probability *= dc->operator()(discreteValues); - } - } - - return probability * exp(-error); + return exp(-error(values)); } /* ************************************************************************* */ @@ -317,12 +295,6 @@ HybridValues HybridBayesNet::sample() const { return sample(&kRandomNumberGenerator); } -/* ************************************************************************* */ -double HybridBayesNet::error(const HybridValues &values) const { - GaussianBayesNet gbn = choose(values.discrete()); - return gbn.error(values.continuous()); -} - /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::error( const VectorValues &continuousValues) const { @@ -332,19 +304,15 @@ AlgebraicDecisionTree HybridBayesNet::error( for (auto &&conditional : *this) { if (auto gm = conditional->asMixture()) { // If conditional is hybrid, select based on assignment and compute error. - AlgebraicDecisionTree conditional_error = - gm->error(continuousValues); - - error_tree = error_tree + conditional_error; + error_tree = error_tree + gm->error(continuousValues); } else if (auto gc = conditional->asGaussian()) { - // If continuous only, get the (double) error - // and add it to the error_tree + // If continuous, get the (double) error and add it to the error_tree double error = gc->error(continuousValues); // Add the computed error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); } else if (auto dc = conditional->asDiscrete()) { - // Conditional is discrete-only, we skip. + // TODO(dellaert): if discrete, we need to add error in the right branch? continue; } } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index dd8d38a4c..c5a16f9dd 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -187,14 +187,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. HybridBayesNet prune(size_t maxNrLeaves); - /** - * @brief 0.5 * sum of squared Mahalanobis distances - * for a specific discrete assignment. - * - * @param values Continuous values and discrete assignment. - * @return double - */ - double error(const HybridValues &values) const; + using Base::error; // Expose error(const HybridValues&) method.. /** * @brief Compute conditional error for each discrete assignment, diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index d0351afbc..c59187f4e 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -463,24 +463,6 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( return error_tree; } -/* ************************************************************************ */ -double HybridGaussianFactorGraph::error(const HybridValues &values) const { - double error = 0.0; - for (auto &f : factors_) { - if (auto hf = dynamic_pointer_cast(f)) { - error += hf->error(values.continuous()); - } else if (auto hf = dynamic_pointer_cast(f)) { - // TODO(dellaert): needs to change when we discard other wrappers. - error += hf->error(values); - } else if (auto dtf = dynamic_pointer_cast(f)) { - error -= log((*dtf)(values.discrete())); - } else { - throwRuntimeError("HybridGaussianFactorGraph::error(HV)", f); - } - } - return error; -} - /* ************************************************************************ */ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { double error = this->error(values); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 0dc737250..0db4f734b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -145,6 +145,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph /// @name Standard Interface /// @{ + using Base::error; // Expose error(const HybridValues&) method.. + /** * @brief Compute error for each discrete assignment, * and return as a tree. @@ -156,14 +158,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph */ AlgebraicDecisionTree error(const VectorValues& continuousValues) const; - /** - * @brief Compute error given a continuous vector values - * and a discrete assignment. - * - * @return double - */ - double error(const HybridValues& values) const; - /** * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ * for each discrete assignment, and return as a tree. diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index 60aee431b..ebefb52cb 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -55,12 +55,18 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { : Base(graph) {} /// @} + /// @name Constructors + /// @{ /// Print the factor graph. void print( const std::string& s = "HybridNonlinearFactorGraph", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + /// @} + /// @name Standard Interface + /// @{ + /** * @brief Linearize all the continuous factors in the * HybridNonlinearFactorGraph. @@ -70,6 +76,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { */ HybridGaussianFactorGraph::shared_ptr linearize( const Values& continuousValues) const; + /// @} }; template <> diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index d62626ea6..2c114a335 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -212,6 +212,7 @@ TEST(HybridBayesNet, Error) { HybridBayesNet::shared_ptr hybridBayesNet = s.linearizedFactorGraph.eliminateSequential(); + EXPECT_LONGS_EQUAL(5, hybridBayesNet->size()); HybridValues delta = hybridBayesNet->optimize(); auto error_tree = hybridBayesNet->error(delta.continuous()); @@ -235,26 +236,21 @@ TEST(HybridBayesNet, Error) { EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6)); // Verify error computation and check for specific error value - DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; + const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; + const HybridValues hybridValues{delta.continuous(), discrete_values}; + double error = 0; + error += hybridBayesNet->at(0)->asMixture()->error(hybridValues); + error += hybridBayesNet->at(1)->asMixture()->error(hybridValues); + error += hybridBayesNet->at(2)->asMixture()->error(hybridValues); - double total_error = 0; - for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { - if (hybridBayesNet->at(idx)->isHybrid()) { - double error = hybridBayesNet->at(idx)->asMixture()->error( - {delta.continuous(), discrete_values}); - total_error += error; - } else if (hybridBayesNet->at(idx)->isContinuous()) { - double error = - hybridBayesNet->at(idx)->asGaussian()->error(delta.continuous()); - total_error += error; - } - } + // TODO(dellaert): the discrete errors are not added in error tree! + EXPECT_DOUBLES_EQUAL(error, error_tree(discrete_values), 1e-9); + EXPECT_DOUBLES_EQUAL(error, pruned_error_tree(discrete_values), 1e-9); + + error += hybridBayesNet->at(3)->asDiscrete()->error(discrete_values); + error += hybridBayesNet->at(4)->asDiscrete()->error(discrete_values); + EXPECT_DOUBLES_EQUAL(error, hybridBayesNet->error(hybridValues), 1e-9); - EXPECT_DOUBLES_EQUAL( - total_error, hybridBayesNet->error({delta.continuous(), discrete_values}), - 1e-9); - EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); - EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); } /* ****************************************************************************/ diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 94a611a9e..99be3ed1c 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -60,12 +60,14 @@ TEST(HybridFactorGraph, GaussianFactorGraph) { Values linearizationPoint; linearizationPoint.insert(X(0), 0); + // Linearize the factor graph. HybridGaussianFactorGraph ghfg = *fg.linearize(linearizationPoint); + EXPECT_LONGS_EQUAL(1, ghfg.size()); - // Add a factor to the GaussianFactorGraph - ghfg.add(JacobianFactor(X(0), I_1x1, Vector1(5))); - - EXPECT_LONGS_EQUAL(2, ghfg.size()); + // Check that the error is the same for the nonlinear values. + const VectorValues zero{{X(0), Vector1(0)}}; + const HybridValues hybridValues{zero, {}, linearizationPoint}; + EXPECT_DOUBLES_EQUAL(fg.error(hybridValues), ghfg.error(hybridValues), 1e-9); } /*************************************************************************** diff --git a/gtsam/inference/FactorGraph-inst.h b/gtsam/inference/FactorGraph-inst.h index d6c1d5d8a..355fdf87e 100644 --- a/gtsam/inference/FactorGraph-inst.h +++ b/gtsam/inference/FactorGraph-inst.h @@ -61,6 +61,16 @@ bool FactorGraph::equals(const This& fg, double tol) const { return true; } +/* ************************************************************************ */ +template +double FactorGraph::error(const HybridValues &values) const { + double error = 0.0; + for (auto &f : factors_) { + error += f->error(values); + } + return error; +} + /* ************************************************************************* */ template size_t FactorGraph::nrFactors() const { diff --git a/gtsam/inference/FactorGraph.h b/gtsam/inference/FactorGraph.h index 68dc79d3f..2e9dd3d53 100644 --- a/gtsam/inference/FactorGraph.h +++ b/gtsam/inference/FactorGraph.h @@ -47,6 +47,8 @@ typedef FastVector FactorIndices; template class BayesTree; +class HybridValues; + /** Helper */ template class CRefCallPushBack { @@ -359,6 +361,9 @@ class FactorGraph { /** Get the last factor */ sharedFactor back() const { return factors_.back(); } + /** Add error for all factors. */ + double error(const HybridValues &values) const; + /// @} /// @name Modifying Factor Graphs (imperative, discouraged) /// @{ diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index c838051cf..b41e1a394 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -145,6 +145,8 @@ namespace gtsam { return exp(logNormalizationConstant()); } + using Base::error; // Expose error(const HybridValues&) method.. + /** * Calculate error(x) == -log(evaluate()) for given values `x`: * - GaussianFactor::error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)