diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h
index cc3c8b8d7..b3f0d69b0 100644
--- a/gtsam/discrete/AlgebraicDecisionTree.h
+++ b/gtsam/discrete/AlgebraicDecisionTree.h
@@ -71,7 +71,7 @@ namespace gtsam {
       static inline double id(const double& x) { return x; }
     };
 
-    AlgebraicDecisionTree() : Base(1.0) {}
+    AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
 
     // Explicitly non-explicit constructor
     AlgebraicDecisionTree(const Base& add) : Base(add) {}
diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp
index fe87795fe..f0d53c416 100644
--- a/gtsam/hybrid/HybridBayesNet.cpp
+++ b/gtsam/hybrid/HybridBayesNet.cpp
@@ -273,4 +273,10 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
   return error_tree;
 }
 
+AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
+    const VectorValues &continuousValues) const {
+  AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
+  return error_tree.apply([](double error) { return exp(-error); });
+}
+
 }  // namespace gtsam
diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h
index f296ba644..c6ac6dcec 100644
--- a/gtsam/hybrid/HybridBayesNet.h
+++ b/gtsam/hybrid/HybridBayesNet.h
@@ -144,6 +144,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
 
+  /**
+   * @brief Compute unnormalized probability for each discrete assignment,
+   * and return as a tree.
+   *
+   * @param continuousValues Continuous values at which to compute the
+   * probability.
+   * @return AlgebraicDecisionTree<Key>
+   */
+  AlgebraicDecisionTree<Key> probPrime(
+      const VectorValues &continuousValues) const;
+
   /// @}
 
  private:
diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
index 9264089aa..d0d2b8d15 100644
--- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp
+++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -439,4 +439,53 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
   return ordering;
 }
 
+/* ************************************************************************ */
+AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
+    const VectorValues &continuousValues) const {
+  AlgebraicDecisionTree<Key> error_tree(0.0);
+
+  for (size_t idx = 0; idx < size(); idx++) {
+    AlgebraicDecisionTree<Key> factor_error;
+
+    if (factors_.at(idx)->isHybrid()) {
+      // If factor is hybrid, select based on assignment.
+      GaussianMixtureFactor::shared_ptr gaussianMixture =
+          boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
+      factor_error = gaussianMixture->error(continuousValues);
+
+      if (idx == 0) {
+        error_tree = factor_error;
+      } else {
+        error_tree = error_tree + factor_error;
+      }
+
+    } else if (factors_.at(idx)->isContinuous()) {
+      // If continuous only, get the (double) error
+      // and add it to the error_tree
+      auto hybridGaussianFactor =
+          boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
+      GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
+
+      double error = gaussian->error(continuousValues);
+      error_tree = error_tree.apply(
+          [error](double leaf_value) { return leaf_value + error; });
+
+    } else if (factors_.at(idx)->isDiscrete()) {
+      // If factor at `idx` is discrete-only, we skip.
+ continue; + } + } + + return error_tree; +} + +/* ************************************************************************ */ +AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree = this->error(continuousValues); + AlgebraicDecisionTree prob_tree = + error_tree.apply([](double error) { return exp(-error); }); + return prob_tree; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 6a0362500..c7e9aa60d 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -41,7 +41,7 @@ class JacobianFactor; /** * @brief Main elimination function for HybridGaussianFactorGraph. - * + * * @param factors The factor graph to eliminate. * @param keys The elimination ordering. * @return The conditional on the ordering keys and the remaining factors. @@ -170,6 +170,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } + /** + * @brief Compute error for each discrete assignment, + * and return as a tree. + * + * @param continuousValues Continuous values at which to compute the error. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree error(const VectorValues& continuousValues) const; + + /** + * @brief Compute unnormalized probability for each discrete assignment, + * and return as a tree. + * + * @param continuousValues Continuous values at which to compute the + * probability. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree probPrime( + const VectorValues& continuousValues) const; + /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys. 
diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
index ed6b97ab0..7877461b6 100644
--- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
+++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
@@ -562,6 +562,36 @@ TEST(HybridGaussianFactorGraph, Conditionals) {
   EXPECT(assert_equal(expected_discrete, result.discrete()));
 }
 
+/* ****************************************************************************/
+// Test hybrid gaussian factor graph error and unnormalized probabilities
+TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
+  Switching s(3);
+
+  HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
+
+  Ordering hybridOrdering = graph.getHybridOrdering();
+  HybridBayesNet::shared_ptr hybridBayesNet =
+      graph.eliminateSequential(hybridOrdering);
+
+  HybridValues delta = hybridBayesNet->optimize();
+  auto error_tree = graph.error(delta.continuous());
+
+  std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
+  std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
+  AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
+
+  // regression
+  EXPECT(assert_equal(expected_error, error_tree, 1e-7));
+
+  auto probs = graph.probPrime(delta.continuous());
+  std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
+                                     0.99029064};
+  AlgebraicDecisionTree<Key> expected_probs(discrete_keys, prob_leaves);
+
+  // regression
+  EXPECT(assert_equal(expected_probs, probs, 1e-7));
+}
+
 /* ************************************************************************* */
 int main() {
   TestResult tr;