diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 56265b0a4..8c04cb91c 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -18,6 +18,7 @@ #include #include +#include #include namespace gtsam { @@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { return result; } +/* ************************************************************************* */ +// The implementation is: build the entire joint into one factor and then prune. +// TODO(Frank): This can be quite expensive *unless* the factors have already +// been pruned before. Another, possibly faster approach is branch and bound +// search to find the K-best leaves and then create a single pruned conditional. +DiscreteBayesNet DiscreteBayesNet::prune( + size_t maxNrLeaves, const std::optional& marginalThreshold, + DiscreteValues* fixedValues) const { + // Multiply into one big conditional. NOTE: possibly quite expensive. + DiscreteConditional joint; + for (const DiscreteConditional::shared_ptr& conditional : *this) + joint = joint * (*conditional); + + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + DiscreteConditional pruned = joint; + pruned.prune(maxNrLeaves); + + DiscreteValues deadModesValues; + // If we have a dead mode threshold and discrete variables left after pruning, + // then we run dead mode removal. 
+ if (marginalThreshold.has_value() && pruned.keys().size() > 0) { + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); + for (auto dkey : pruned.discreteKeys()) { + const Vector probabilities = marginals.marginalProbabilities(dkey); + + int index = -1; + auto threshold = (probabilities.array() > *marginalThreshold); + // If at least 1 value is non-zero, then we can find the index + // Else if all are zero, index would be set to 0 which is incorrect + if (!threshold.isZero()) { + threshold.maxCoeff(&index); + } + + if (index >= 0) { + deadModesValues.emplace(dkey.first, index); + } + } + + // Remove the modes (imperative) + pruned.removeDiscreteModes(deadModesValues); + + // Set the fixed values if requested. + if (fixedValues) { + *fixedValues = deadModesValues; + } + } + + // Return the resulting DiscreteBayesNet. + DiscreteBayesNet result; + if (pruned.keys().size() > 0) result.push_back(pruned); + return result; +} + /* *********************************************************************** */ std::string DiscreteBayesNet::markdown( const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 738b91aa5..eea1739f6 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { */ DiscreteValues sample(DiscreteValues given) const; + /** + * @brief Prune the Bayes net + * + * @param maxNrLeaves The maximum number of leaves to keep. + * @param marginalThreshold If given, threshold on marginals to prune variables. + * @param fixedValues If given, return the fixed values removed. + * @return A new DiscreteBayesNet with pruned conditionals. 
+ */ + DiscreteBayesNet prune(size_t maxNrLeaves, + const std::optional& marginalThreshold = {}, + DiscreteValues* fixedValues = nullptr) const; + ///@} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index df4ecdbff..0644e0c16 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -68,28 +68,82 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x); // insert in base class; - std::pair insert( const value_type& value ){ + std::pair insert(const value_type& value) { return Base::insert(value); } /** - * Insert key-assignment pair. - * Throws an invalid_argument exception if - * any keys to be inserted are already used. */ + * @brief Insert key-assignment pair. + * + * @param assignment The key-assignment pair to insert. + * @return DiscreteValues& Reference to the updated DiscreteValues object. + * @throws std::invalid_argument if any keys to be inserted are already used. + */ DiscreteValues& insert(const std::pair& assignment); - /** Insert all values from \c values. Throws an invalid_argument exception if - * any keys to be inserted are already used. */ + /** + * @brief Insert all values from another DiscreteValues object. + * + * @param values The DiscreteValues object containing values to insert. + * @return DiscreteValues& Reference to the updated DiscreteValues object. + * @throws std::invalid_argument if any keys to be inserted are already used. + */ DiscreteValues& insert(const DiscreteValues& values); - /** For all key/value pairs in \c values, replace values with corresponding - * keys in this object with those in \c values. Throws std::out_of_range if - * any keys in \c values are not present in this object. */ + /** + * @brief Update values with corresponding keys from another DiscreteValues + * object. + * + * @param values The DiscreteValues object containing values to update. 
+ * @return DiscreteValues& Reference to the updated DiscreteValues object. + * @throws std::out_of_range if any keys in values are not present in this + * object. + */ DiscreteValues& update(const DiscreteValues& values); /** - * @brief Return a vector of DiscreteValues, one for each possible - * combination of values. + * @brief Check if the DiscreteValues contains the given key. + * + * @param key The key to check for. + * @return True if the key is present, false otherwise. + */ + bool contains(Key key) const { return this->find(key) != this->end(); } + + /** + * @brief Filter values by keys. + * + * @param keys The keys to filter by. + * @return DiscreteValues The filtered DiscreteValues object. + */ + DiscreteValues filter(const DiscreteKeys& keys) const { + DiscreteValues result; + for (const auto& [key, _] : keys) { + if (auto it = this->find(key); it != this->end()) + result[key] = it->second; + } + return result; + } + + /** + * @brief Return the keys that are not present in the DiscreteValues object. + * + * @param keys The keys to check for. + * @return DiscreteKeys Keys not present in the DiscreteValues object. + */ + DiscreteKeys missingKeys(const DiscreteKeys& keys) const { + DiscreteKeys result; + for (const auto& [key, cardinality] : keys) { + if (!this->contains(key)) result.emplace_back(key, cardinality); + } + return result; + } + + /** + * @brief Return a vector of DiscreteValues, one for each possible combination + * of values. + * + * @param keys The keys to generate the Cartesian product for. + * @return std::vector The vector of DiscreteValues. */ static std::vector CartesianProduct( const DiscreteKeys& keys) { @@ -135,14 +189,16 @@ inline std::vector cartesianProduct(const DiscreteKeys& keys) { } /// Free version of markdown. 
-std::string GTSAM_EXPORT markdown(const DiscreteValues& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const DiscreteValues::Names& names = {}); +std::string GTSAM_EXPORT +markdown(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); /// Free version of html. -std::string GTSAM_EXPORT html(const DiscreteValues& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const DiscreteValues::Names& names = {}); +std::string GTSAM_EXPORT +html(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); // traits template <> diff --git a/gtsam/discrete/tests/testDiscreteValues.cpp b/gtsam/discrete/tests/testDiscreteValues.cpp index 47989481c..4667bc5ea 100644 --- a/gtsam/discrete/tests/testDiscreteValues.cpp +++ b/gtsam/discrete/tests/testDiscreteValues.cpp @@ -40,6 +40,24 @@ TEST(DiscreteValues, Update) { DiscreteValues(kExample).update({{12, 2}}))); } +/* ************************************************************************* */ +// Test DiscreteValues::filter +TEST(DiscreteValues, Filter) { + DiscreteValues values = {{12, 1}, {5, 0}, {13, 2}}; + DiscreteKeys keys = {{12, 0}, {13, 0}, {99, 0}}; // 99 is missing in values + + EXPECT(assert_equal(DiscreteValues({{12, 1}, {13, 2}}), values.filter(keys))); +} + +/* ************************************************************************* */ +// Test DiscreteValues::missingKeys +TEST(DiscreteValues, MissingKeys) { + DiscreteValues values = {{12, 1}, {5, 0}}; + DiscreteKeys keys = {{12, 0}, {5, 0}, {99, 0}, {42, 0}}; // 99 and 42 are missing + + EXPECT(assert_equal(DiscreteKeys({{99, 0}, {42, 0}}), values.missingKeys(keys))); +} + /* ************************************************************************* */ // Check markdown representation with a value formatter. 
TEST(DiscreteValues, markdownWithValueFormatter) { diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e64284a94..5353fe2e0 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -43,115 +42,60 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { } /* ************************************************************************* */ -// The implementation is: build the entire joint into one factor and then prune. -// TODO(Frank): This can be quite expensive *unless* the factors have already -// been pruned before. Another, possibly faster approach is branch and bound -// search to find the K-best leaves and then create a single pruned conditional. HybridBayesNet HybridBayesNet::prune( - size_t maxNrLeaves, const std::optional &deadModeThreshold) const { + size_t maxNrLeaves, const std::optional &marginalThreshold, + DiscreteValues *fixedValues) const { +#if GTSAM_HYBRID_TIMING + gttic_(HybridPruning); +#endif // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); + // Prune discrete Bayes net + DiscreteValues fixed; + auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed); + // Multiply into one big conditional. NOTE: possibly quite expensive. - DiscreteConditional joint; - for (auto &&conditional : marginal) { - joint = joint * (*conditional); + DiscreteConditional pruned; + for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); + + // Set the fixed values if requested. + if (marginalThreshold && fixedValues) { + *fixedValues = fixed; } - // Initialize the resulting HybridBayesNet. HybridBayesNet result; - // Prune the joint. NOTE: imperative and, again, possibly quite expensive. 
- DiscreteConditional pruned = joint; - pruned.prune(maxNrLeaves); + // Go through all the Gaussian conditionals, restrict them according to + // fixed values, and then prune further. + for (std::shared_ptr conditional : *this) { + if (conditional->isDiscrete()) continue; - DiscreteValues deadModesValues; - // If we have a dead mode threshold and discrete variables left after pruning, - // then we run dead mode removal. - if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { - DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); - for (auto dkey : pruned.discreteKeys()) { - Vector probabilities = marginals.marginalProbabilities(dkey); + // No-op if not a HybridGaussianConditional. + if (marginalThreshold) conditional = conditional->restrict(fixed); - int index = -1; - auto threshold = (probabilities.array() > *deadModeThreshold); - // If atleast 1 value is non-zero, then we can find the index - // Else if all are zero, index would be set to 0 which is incorrect - if (!threshold.isZero()) { - threshold.maxCoeff(&index); - } - - if (index >= 0) { - deadModesValues.emplace(dkey.first, index); - } - } - - // Remove the modes (imperative) - pruned.removeDiscreteModes(deadModesValues); - - /* - If the pruned discrete conditional has any keys left, - we add it to the HybridBayesNet. - If not, it means it is an orphan so we don't add this pruned joint, - and instead add only the marginals below. - */ - if (pruned.keys().size() > 0) { - result.emplace_shared(pruned); - } - - // Add the marginals for future factors - for (auto &&[key, _] : deadModesValues) { - result.push_back( - std::dynamic_pointer_cast(marginals(key))); - } - - } else { - result.emplace_shared(pruned); - } - - /* To prune, we visitWith every leaf in the HybridGaussianConditional. - * For each leaf, using the assignment we can check the discrete decision tree - * for 0.0 probability, then just set the leaf to a nullptr. 
- * - * We can later check the HybridGaussianConditional for just nullptrs. - */ - - // Go through all the Gaussian conditionals in the Bayes Net and prune them as - // per pruned discrete joint. - for (auto &&conditional : *this) { + // Now decide what to do based on the type: if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); - - if (deadModeThreshold.has_value()) { - KeyVector deadKeys, conditionalDiscreteKeys; - for (const auto &kv : deadModesValues) { - deadKeys.push_back(kv.first); - } - for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) { - conditionalDiscreteKeys.push_back(dkey.first); - } - // The discrete keys in the conditional are the same as the keys in the - // dead modes, then we just get the corresponding Gaussian conditional. - if (deadKeys == conditionalDiscreteKeys) { - result.push_back( - prunedHybridGaussianConditional->choose(deadModesValues)); - } else { - // Add as-is - result.push_back(prunedHybridGaussianConditional); - } - } else { - // Type-erase and add to the pruned Bayes Net fragment. - result.push_back(prunedHybridGaussianConditional); + if (!prunedHybridGaussianConditional) { + throw std::runtime_error( + "A HybridGaussianConditional had all its conditionals pruned"); } - + // Type-erase and add to the pruned Bayes Net fragment. + result.push_back(prunedHybridGaussianConditional); } else if (auto gc = conditional->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); - } - // We ignore DiscreteConditional as they are already pruned and added. + } else + throw std::runtime_error( + "HybridBayesNet::prune: Unknown HybridConditional type."); } + // Add the pruned discrete conditionals to the result. 
+ for (const DiscreteConditional::shared_ptr &discrete : prunedBN) + result.push_back(discrete); + return result; } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index fb05e2407..0840cb381 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. - * @param deadModeThreshold The threshold to check the mode marginals against. - * If greater than this threshold, the mode gets assigned that value and is - * considered "dead" for hybrid elimination. - * The mode can then be removed since it only has a single possible - * assignment. + * @param marginalThreshold The threshold to check the mode marginals against. + * @param fixedValues The fixed values resulting from dead mode removal. + * + * @note If marginal greater than this threshold, the mode gets assigned that + * value and is considered "dead" for hybrid elimination. The mode can then be + * removed since it only has a single possible assignment. 
+ * @return A pruned HybridBayesNet */ - HybridBayesNet prune( - size_t maxNrLeaves, - const std::optional &deadModeThreshold = {}) const; + HybridBayesNet prune(size_t maxNrLeaves, + const std::optional &marginalThreshold = {}, + DiscreteValues *fixedValues = nullptr) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 97ec1a1f8..257eca314 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const { return std::exp(logProbability(values)); } +/* ************************************************************************ */ +HybridConditional::shared_ptr HybridConditional::restrict( + const DiscreteValues &discreteValues) const { + if (auto gc = asGaussian()) { + return std::make_shared(gc); + } else if (auto dc = asDiscrete()) { + return std::make_shared(dc); + }; + + auto hgc = asHybrid(); + if (!hgc) + throw std::runtime_error( + "HybridConditional::restrict: conditional type not handled"); + + // Case 1: Fully determined, return corresponding Gaussian conditional + auto parentValues = discreteValues.filter(discreteKeys_); + if (parentValues.size() == discreteKeys_.size()) { + return std::make_shared(hgc->choose(parentValues)); + } + + // Case 2: Some live parents remain, build a new tree + auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_); + if (!unspecifiedParentKeys.empty()) { + auto newTree = hgc->factors(); + for (const auto &[key, value] : parentValues) { + newTree = newTree.choose(key, value); + } + return std::make_shared( + std::make_shared(unspecifiedParentKeys, + newTree)); + } + + // Case 3: No changes needed, return original + return std::make_shared(hgc); +} + +/* ************************************************************************ */ } // namespace gtsam diff --git 
a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 3cf5b80e5..075fbe411 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -215,6 +215,14 @@ class GTSAM_EXPORT HybridConditional return true; } + /** + * Return a HybridConditional by choosing branches based on the given discrete + * values. If all discrete parents are specified, return a HybridConditional + * which is just a GaussianConditional. If this conditional is *not* a hybrid + * conditional, just return that. + */ + shared_ptr restrict(const DiscreteValues& discreteValues) const; + /// @} private: diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 45320896a..594c12825 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_); } // Add the partial bayes net to the posterior bayes net. diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index c3f022c62..2f7bfcebb 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother { HybridGaussianFactorGraph remainingFactorGraph_; /// The threshold above which we make a decision about a mode. - std::optional deadModeThreshold_; + std::optional marginalThreshold_; + DiscreteValues fixedValues_; public: /** * @brief Constructor * * @param removeDeadModes Flag indicating whether to remove dead modes. 
- * @param deadModeThreshold The threshold above which a mode gets assigned a + * @param marginalThreshold The threshold above which a mode gets assigned a * value and is considered "dead". 0.99 is a good starting value. */ - HybridSmoother(const std::optional deadModeThreshold = {}) - : deadModeThreshold_(deadModeThreshold) {} + HybridSmoother(const std::optional marginalThreshold = {}) + : marginalThreshold_(marginalThreshold) {} /** * Given new factors, perform an incremental update. diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 86dcd48e4..3ddad23ff 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -166,8 +166,8 @@ TEST(HybridBayesNet, Tiny) { // prune auto pruned = bayesNet.prune(1); - CHECK(pruned.at(1)->asHybrid()); - EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents()); + CHECK(pruned.at(0)->asHybrid()); + EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); EXPECT(!pruned.equals(bayesNet)); // error @@ -437,20 +437,21 @@ TEST(HybridBayesNet, RemoveDeadNodes) { const double pruneDeadVariables = 0.99; auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); + // First conditional is still the same: P( x0 | x1 m0) + EXPECT(prunedBayesNet.at(0)->isHybrid()); + + // Check that hybrid conditional that only depends on M1 + // is now Gaussian and not Hybrid + EXPECT(prunedBayesNet.at(1)->isContinuous()); + + // Third conditional is still Hybrid: P( x1 | m0 m1) -> P( x1 | m0) + EXPECT(prunedBayesNet.at(2)->isHybrid()); + // Check that discrete joint only has M0 and not (M0, M1) // since M0 is removed - KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys(); - EXPECT(KeyVector{M(0)} == actual_keys); - - // Check that hybrid conditionals that only depend on M1 - // are now Gaussian and not Hybrid - EXPECT(prunedBayesNet.at(0)->isDiscrete()); - EXPECT(prunedBayesNet.at(1)->isDiscrete()); - 
EXPECT(prunedBayesNet.at(2)->isHybrid()); - // Only P(X2 | X1, M1) depends on M1, - // so it gets convert to a Gaussian P(X2 | X1) - EXPECT(prunedBayesNet.at(3)->isContinuous()); - EXPECT(prunedBayesNet.at(4)->isHybrid()); + auto joint = prunedBayesNet.at(3)->asDiscrete(); + EXPECT(joint); + EXPECT(joint->keys() == KeyVector{M(0)}); } /* ****************************************************************************/ @@ -479,13 +480,13 @@ TEST(HybridBayesNet, ErrorAfterPruning) { const HybridValues hybridValues{delta.continuous(), discrete_values}; double pruned_logProbability = 0; pruned_logProbability += - prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues); + prunedBayesNet.at(0)->asHybrid()->logProbability(hybridValues); pruned_logProbability += prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues); pruned_logProbability += prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues); pruned_logProbability += - prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues); + prunedBayesNet.at(3)->asDiscrete()->logProbability(hybridValues); double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values); @@ -548,8 +549,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { }; // Get the pruned discrete conditionals as an AlgebraicDecisionTree - CHECK(pruned.at(0)->asDiscrete()); - auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete(); + CHECK(pruned.at(4)->asDiscrete()); + auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete(); auto discrete_conditional_tree = std::dynamic_pointer_cast( pruned_discrete_conditionals); diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 8bb83cac4..88a4fa485 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -238,22 +239,27 @@ 
TEST(HybridGaussianConditional, Likelihood2) { EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); } +/* ************************************************************************* */ +namespace two_mode_measurement { +// Create a two key conditional: +const DiscreteKeys modes{{M(1), 2}, {M(2), 2}}; +const std::vector gcs = { + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)}; +const HybridGaussianConditional::Conditionals conditionals(modes, gcs); +const auto hgc = + std::make_shared(modes, conditionals); +} // namespace two_mode_measurement + /* ************************************************************************* */ // Test pruning a HybridGaussianConditional with two discrete keys, based on a // DecisionTreeFactor with 3 keys: TEST(HybridGaussianConditional, Prune) { - // Create a two key conditional: - DiscreteKeys modes{{M(1), 2}, {M(2), 2}}; - std::vector gcs; - for (size_t i = 0; i < 4; i++) { - gcs.push_back( - GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1)); - } - auto empty = std::make_shared(); - HybridGaussianConditional::Conditionals conditionals(modes, gcs); - HybridGaussianConditional hgc(modes, conditionals); + using two_mode_measurement::hgc; - DiscreteKeys keys = modes; + DiscreteKeys keys = two_mode_measurement::modes; keys.push_back({M(3), 2}); { for (size_t i = 0; i < 8; i++) { @@ -262,7 +268,7 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); // Prune the HybridGaussianConditional const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 1 conditional EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); } @@ -273,14 
+279,14 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 2 conditionals EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); // Check that the minimum negLogConstant is set correctly EXPECT_DOUBLES_EQUAL( - hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), + hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), pruned->negLogConstant(), 1e-9); } { @@ -289,18 +295,48 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 3 conditionals EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); // Check that the minimum negLogConstant is correct - EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9); + EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9); } } /* ************************************************************************* + * This test verifies the behavior of the restrict method in different + * scenarios: + * - When no restrictions are applied. + * - When one parent is restricted. + * - When two parents are restricted. + * - When the restriction results in a Gaussian conditional. 
*/ +TEST(HybridGaussianConditional, Restrict) { + // Create a HybridConditional with two discrete parents P(z0|m0,m1) + const auto hc = + std::make_shared(two_mode_measurement::hgc); + + const HybridConditional::shared_ptr same = hc->restrict({}); + EXPECT(same->isHybrid()); + EXPECT(same->asHybrid()->nrComponents() == 4); + + const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}}); + EXPECT(oneParent->isHybrid()); + EXPECT(oneParent->asHybrid()->nrComponents() == 2); + + const HybridConditional::shared_ptr oneParent2 = + hc->restrict({{M(7), 0}, {M(1), 0}}); + EXPECT(oneParent2->isHybrid()); + EXPECT(oneParent2->asHybrid()->nrComponents() == 2); + + const HybridConditional::shared_ptr gaussian = + hc->restrict({{M(1), 0}, {M(2), 1}}); + EXPECT(gaussian->asGaussian()); +} + +/* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 3a0f376cc..97a302faf 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -103,7 +103,7 @@ TEST(HybridSmoother, IncrementalSmoother) { } EXPECT_LONGS_EQUAL(11, - smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. @@ -157,7 +157,7 @@ TEST(HybridSmoother, ValidPruningError) { } EXPECT_LONGS_EQUAL(14, - smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment.