commit
3cf15901c7
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// The implementation is: build the entire joint into one factor and then prune.
|
||||
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
||||
// been pruned before. Another, possibly faster approach is branch and bound
|
||||
// search to find the K-best leaves and then create a single pruned conditional.
|
||||
DiscreteBayesNet DiscreteBayesNet::prune(
|
||||
size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
|
||||
DiscreteValues* fixedValues) const {
|
||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||
DiscreteConditional joint;
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
joint = joint * (*conditional);
|
||||
|
||||
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
||||
DiscreteConditional pruned = joint;
|
||||
pruned.prune(maxNrLeaves);
|
||||
|
||||
DiscreteValues deadModesValues;
|
||||
// If we have a dead mode threshold and discrete variables left after pruning,
|
||||
// then we run dead mode removal.
|
||||
if (marginalThreshold.has_value() && pruned.keys().size() > 0) {
|
||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||
for (auto dkey : pruned.discreteKeys()) {
|
||||
const Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||
|
||||
int index = -1;
|
||||
auto threshold = (probabilities.array() > *marginalThreshold);
|
||||
// If atleast 1 value is non-zero, then we can find the index
|
||||
// Else if all are zero, index would be set to 0 which is incorrect
|
||||
if (!threshold.isZero()) {
|
||||
threshold.maxCoeff(&index);
|
||||
}
|
||||
|
||||
if (index >= 0) {
|
||||
deadModesValues.emplace(dkey.first, index);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the modes (imperative)
|
||||
pruned.removeDiscreteModes(deadModesValues);
|
||||
|
||||
// Set the fixed values if requested.
|
||||
if (fixedValues) {
|
||||
*fixedValues = deadModesValues;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the resulting DiscreteBayesNet.
|
||||
DiscreteBayesNet result;
|
||||
if (pruned.keys().size() > 0) result.push_back(pruned);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::markdown(
|
||||
const KeyFormatter& keyFormatter,
|
||||
|
|
|
@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
|||
*/
|
||||
DiscreteValues sample(DiscreteValues given) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the Bayes net
|
||||
*
|
||||
* @param maxNrLeaves The maximum number of leaves to keep.
|
||||
* @param marginalThreshold If given, threshold on marginals to prune variables.
|
||||
* @param fixedValues If given, return the fixed values removed.
|
||||
* @return A new DiscreteBayesNet with pruned conditionals.
|
||||
*/
|
||||
DiscreteBayesNet prune(size_t maxNrLeaves,
|
||||
const std::optional<double>& marginalThreshold = {},
|
||||
DiscreteValues* fixedValues = nullptr) const;
|
||||
|
||||
///@}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -73,23 +73,77 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
|
|||
}
|
||||
|
||||
/**
|
||||
* Insert key-assignment pair.
|
||||
* Throws an invalid_argument exception if
|
||||
* any keys to be inserted are already used. */
|
||||
* @brief Insert key-assignment pair.
|
||||
*
|
||||
* @param assignment The key-assignment pair to insert.
|
||||
* @return DiscreteValues& Reference to the updated DiscreteValues object.
|
||||
* @throws std::invalid_argument if any keys to be inserted are already used.
|
||||
*/
|
||||
DiscreteValues& insert(const std::pair<Key, size_t>& assignment);
|
||||
|
||||
/** Insert all values from \c values. Throws an invalid_argument exception if
|
||||
* any keys to be inserted are already used. */
|
||||
/**
|
||||
* @brief Insert all values from another DiscreteValues object.
|
||||
*
|
||||
* @param values The DiscreteValues object containing values to insert.
|
||||
* @return DiscreteValues& Reference to the updated DiscreteValues object.
|
||||
* @throws std::invalid_argument if any keys to be inserted are already used.
|
||||
*/
|
||||
DiscreteValues& insert(const DiscreteValues& values);
|
||||
|
||||
/** For all key/value pairs in \c values, replace values with corresponding
|
||||
* keys in this object with those in \c values. Throws std::out_of_range if
|
||||
* any keys in \c values are not present in this object. */
|
||||
/**
|
||||
* @brief Update values with corresponding keys from another DiscreteValues
|
||||
* object.
|
||||
*
|
||||
* @param values The DiscreteValues object containing values to update.
|
||||
* @return DiscreteValues& Reference to the updated DiscreteValues object.
|
||||
* @throws std::out_of_range if any keys in values are not present in this
|
||||
* object.
|
||||
*/
|
||||
DiscreteValues& update(const DiscreteValues& values);
|
||||
|
||||
/**
|
||||
* @brief Return a vector of DiscreteValues, one for each possible
|
||||
* combination of values.
|
||||
* @brief Check if the DiscreteValues contains the given key.
|
||||
*
|
||||
* @param key The key to check for.
|
||||
* @return True if the key is present, false otherwise.
|
||||
*/
|
||||
bool contains(Key key) const { return this->find(key) != this->end(); }
|
||||
|
||||
/**
|
||||
* @brief Filter values by keys.
|
||||
*
|
||||
* @param keys The keys to filter by.
|
||||
* @return DiscreteValues The filtered DiscreteValues object.
|
||||
*/
|
||||
DiscreteValues filter(const DiscreteKeys& keys) const {
|
||||
DiscreteValues result;
|
||||
for (const auto& [key, _] : keys) {
|
||||
if (auto it = this->find(key); it != this->end())
|
||||
result[key] = it->second;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the keys that are not present in the DiscreteValues object.
|
||||
*
|
||||
* @param keys The keys to check for.
|
||||
* @return DiscreteKeys Keys not present in the DiscreteValues object.
|
||||
*/
|
||||
DiscreteKeys missingKeys(const DiscreteKeys& keys) const {
|
||||
DiscreteKeys result;
|
||||
for (const auto& [key, cardinality] : keys) {
|
||||
if (!this->contains(key)) result.emplace_back(key, cardinality);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return a vector of DiscreteValues, one for each possible combination
|
||||
* of values.
|
||||
*
|
||||
* @param keys The keys to generate the Cartesian product for.
|
||||
* @return std::vector<DiscreteValues> The vector of DiscreteValues.
|
||||
*/
|
||||
static std::vector<DiscreteValues> CartesianProduct(
|
||||
const DiscreteKeys& keys) {
|
||||
|
@ -135,12 +189,14 @@ inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
|
|||
}
|
||||
|
||||
/// Free version of markdown.
|
||||
std::string GTSAM_EXPORT markdown(const DiscreteValues& values,
|
||||
std::string GTSAM_EXPORT
|
||||
markdown(const DiscreteValues& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteValues::Names& names = {});
|
||||
|
||||
/// Free version of html.
|
||||
std::string GTSAM_EXPORT html(const DiscreteValues& values,
|
||||
std::string GTSAM_EXPORT
|
||||
html(const DiscreteValues& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteValues::Names& names = {});
|
||||
|
||||
|
|
|
@ -40,6 +40,24 @@ TEST(DiscreteValues, Update) {
|
|||
DiscreteValues(kExample).update({{12, 2}})));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test DiscreteValues::filter
|
||||
TEST(DiscreteValues, Filter) {
|
||||
DiscreteValues values = {{12, 1}, {5, 0}, {13, 2}};
|
||||
DiscreteKeys keys = {{12, 0}, {13, 0}, {99, 0}}; // 99 is missing in values
|
||||
|
||||
EXPECT(assert_equal(DiscreteValues({{12, 1}, {13, 2}}), values.filter(keys)));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test DiscreteValues::missingKeys
|
||||
TEST(DiscreteValues, MissingKeys) {
|
||||
DiscreteValues values = {{12, 1}, {5, 0}};
|
||||
DiscreteKeys keys = {{12, 0}, {5, 0}, {99, 0}, {42, 0}}; // 99 and 42 are missing
|
||||
|
||||
EXPECT(assert_equal(DiscreteKeys({{99, 0}, {42, 0}}), values.missingKeys(keys)));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation with a value formatter.
|
||||
TEST(DiscreteValues, markdownWithValueFormatter) {
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
@ -43,114 +42,59 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// The implementation is: build the entire joint into one factor and then prune.
|
||||
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
||||
// been pruned before. Another, possibly faster approach is branch and bound
|
||||
// search to find the K-best leaves and then create a single pruned conditional.
|
||||
HybridBayesNet HybridBayesNet::prune(
|
||||
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
||||
size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
|
||||
DiscreteValues *fixedValues) const {
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(HybridPruning);
|
||||
#endif
|
||||
// Collect all the discrete conditionals. Could be small if already pruned.
|
||||
const DiscreteBayesNet marginal = discreteMarginal();
|
||||
|
||||
// Prune discrete Bayes net
|
||||
DiscreteValues fixed;
|
||||
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
|
||||
|
||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||
DiscreteConditional joint;
|
||||
for (auto &&conditional : marginal) {
|
||||
joint = joint * (*conditional);
|
||||
DiscreteConditional pruned;
|
||||
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
|
||||
|
||||
// Set the fixed values if requested.
|
||||
if (marginalThreshold && fixedValues) {
|
||||
*fixedValues = fixed;
|
||||
}
|
||||
|
||||
// Initialize the resulting HybridBayesNet.
|
||||
HybridBayesNet result;
|
||||
|
||||
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
||||
DiscreteConditional pruned = joint;
|
||||
pruned.prune(maxNrLeaves);
|
||||
// Go through all the Gaussian conditionals, restrict them according to
|
||||
// fixed values, and then prune further.
|
||||
for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
|
||||
if (conditional->isDiscrete()) continue;
|
||||
|
||||
DiscreteValues deadModesValues;
|
||||
// If we have a dead mode threshold and discrete variables left after pruning,
|
||||
// then we run dead mode removal.
|
||||
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
|
||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||
for (auto dkey : pruned.discreteKeys()) {
|
||||
Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||
// No-op if not a HybridGaussianConditional.
|
||||
if (marginalThreshold) conditional = conditional->restrict(fixed);
|
||||
|
||||
int index = -1;
|
||||
auto threshold = (probabilities.array() > *deadModeThreshold);
|
||||
// If atleast 1 value is non-zero, then we can find the index
|
||||
// Else if all are zero, index would be set to 0 which is incorrect
|
||||
if (!threshold.isZero()) {
|
||||
threshold.maxCoeff(&index);
|
||||
}
|
||||
|
||||
if (index >= 0) {
|
||||
deadModesValues.emplace(dkey.first, index);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the modes (imperative)
|
||||
pruned.removeDiscreteModes(deadModesValues);
|
||||
|
||||
/*
|
||||
If the pruned discrete conditional has any keys left,
|
||||
we add it to the HybridBayesNet.
|
||||
If not, it means it is an orphan so we don't add this pruned joint,
|
||||
and instead add only the marginals below.
|
||||
*/
|
||||
if (pruned.keys().size() > 0) {
|
||||
result.emplace_shared<DiscreteConditional>(pruned);
|
||||
}
|
||||
|
||||
// Add the marginals for future factors
|
||||
for (auto &&[key, _] : deadModesValues) {
|
||||
result.push_back(
|
||||
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
|
||||
}
|
||||
|
||||
} else {
|
||||
result.emplace_shared<DiscreteConditional>(pruned);
|
||||
}
|
||||
|
||||
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
*
|
||||
* We can later check the HybridGaussianConditional for just nullptrs.
|
||||
*/
|
||||
|
||||
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
|
||||
// per pruned discrete joint.
|
||||
for (auto &&conditional : *this) {
|
||||
// Now decide on type what to do:
|
||||
if (auto hgc = conditional->asHybrid()) {
|
||||
// Prune the hybrid Gaussian conditional!
|
||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||
|
||||
if (deadModeThreshold.has_value()) {
|
||||
KeyVector deadKeys, conditionalDiscreteKeys;
|
||||
for (const auto &kv : deadModesValues) {
|
||||
deadKeys.push_back(kv.first);
|
||||
if (!prunedHybridGaussianConditional) {
|
||||
throw std::runtime_error(
|
||||
"A HybridGaussianConditional had all its conditionals pruned");
|
||||
}
|
||||
for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) {
|
||||
conditionalDiscreteKeys.push_back(dkey.first);
|
||||
}
|
||||
// The discrete keys in the conditional are the same as the keys in the
|
||||
// dead modes, then we just get the corresponding Gaussian conditional.
|
||||
if (deadKeys == conditionalDiscreteKeys) {
|
||||
result.push_back(
|
||||
prunedHybridGaussianConditional->choose(deadModesValues));
|
||||
} else {
|
||||
// Add as-is
|
||||
result.push_back(prunedHybridGaussianConditional);
|
||||
}
|
||||
} else {
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
result.push_back(prunedHybridGaussianConditional);
|
||||
}
|
||||
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// Add the non-HybridGaussianConditional conditional
|
||||
result.push_back(gc);
|
||||
} else
|
||||
throw std::runtime_error(
|
||||
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
|
||||
}
|
||||
// We ignore DiscreteConditional as they are already pruned and added.
|
||||
}
|
||||
|
||||
// Add the pruned discrete conditionals to the result.
|
||||
for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
|
||||
result.push_back(discrete);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
*
|
||||
* @param maxNrLeaves Continuous values at which to compute the error.
|
||||
* @param deadModeThreshold The threshold to check the mode marginals against.
|
||||
* If greater than this threshold, the mode gets assigned that value and is
|
||||
* considered "dead" for hybrid elimination.
|
||||
* The mode can then be removed since it only has a single possible
|
||||
* assignment.
|
||||
* @param marginalThreshold The threshold to check the mode marginals against.
|
||||
* @param fixedValues The fixed values resulting from dead mode removal.
|
||||
*
|
||||
* @note If marginal greater than this threshold, the mode gets assigned that
|
||||
* value and is considered "dead" for hybrid elimination. The mode can then be
|
||||
* removed since it only has a single possible assignment.
|
||||
|
||||
* @return A pruned HybridBayesNet
|
||||
*/
|
||||
HybridBayesNet prune(
|
||||
size_t maxNrLeaves,
|
||||
const std::optional<double> &deadModeThreshold = {}) const;
|
||||
HybridBayesNet prune(size_t maxNrLeaves,
|
||||
const std::optional<double> &marginalThreshold = {},
|
||||
DiscreteValues *fixedValues = nullptr) const;
|
||||
|
||||
/**
|
||||
* @brief Error method using HybridValues which returns specific error for
|
||||
|
|
|
@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const {
|
|||
return std::exp(logProbability(values));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
HybridConditional::shared_ptr HybridConditional::restrict(
|
||||
const DiscreteValues &discreteValues) const {
|
||||
if (auto gc = asGaussian()) {
|
||||
return std::make_shared<HybridConditional>(gc);
|
||||
} else if (auto dc = asDiscrete()) {
|
||||
return std::make_shared<HybridConditional>(dc);
|
||||
};
|
||||
|
||||
auto hgc = asHybrid();
|
||||
if (!hgc)
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::restrict: conditional type not handled");
|
||||
|
||||
// Case 1: Fully determined, return corresponding Gaussian conditional
|
||||
auto parentValues = discreteValues.filter(discreteKeys_);
|
||||
if (parentValues.size() == discreteKeys_.size()) {
|
||||
return std::make_shared<HybridConditional>(hgc->choose(parentValues));
|
||||
}
|
||||
|
||||
// Case 2: Some live parents remain, build a new tree
|
||||
auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_);
|
||||
if (!unspecifiedParentKeys.empty()) {
|
||||
auto newTree = hgc->factors();
|
||||
for (const auto &[key, value] : parentValues) {
|
||||
newTree = newTree.choose(key, value);
|
||||
}
|
||||
return std::make_shared<HybridConditional>(
|
||||
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
|
||||
newTree));
|
||||
}
|
||||
|
||||
// Case 3: No changes needed, return original
|
||||
return std::make_shared<HybridConditional>(hgc);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -215,6 +215,14 @@ class GTSAM_EXPORT HybridConditional
|
|||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a HybridConditional by choosing branches based on the given discrete
|
||||
* values. If all discrete parents are specified, return a HybridConditional
|
||||
* which is just a GaussianConditional. If this conditional is *not* a hybrid
|
||||
* conditional, just return that.
|
||||
*/
|
||||
shared_ptr restrict(const DiscreteValues& discreteValues) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
|||
if (maxNrLeaves) {
|
||||
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
||||
// all the conditionals with the same keys in bayesNetFragment.
|
||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
|
||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_);
|
||||
}
|
||||
|
||||
// Add the partial bayes net to the posterior bayes net.
|
||||
|
|
|
@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother {
|
|||
HybridGaussianFactorGraph remainingFactorGraph_;
|
||||
|
||||
/// The threshold above which we make a decision about a mode.
|
||||
std::optional<double> deadModeThreshold_;
|
||||
std::optional<double> marginalThreshold_;
|
||||
DiscreteValues fixedValues_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor
|
||||
*
|
||||
* @param removeDeadModes Flag indicating whether to remove dead modes.
|
||||
* @param deadModeThreshold The threshold above which a mode gets assigned a
|
||||
* @param marginalThreshold The threshold above which a mode gets assigned a
|
||||
* value and is considered "dead". 0.99 is a good starting value.
|
||||
*/
|
||||
HybridSmoother(const std::optional<double> deadModeThreshold = {})
|
||||
: deadModeThreshold_(deadModeThreshold) {}
|
||||
HybridSmoother(const std::optional<double> marginalThreshold = {})
|
||||
: marginalThreshold_(marginalThreshold) {}
|
||||
|
||||
/**
|
||||
* Given new factors, perform an incremental update.
|
||||
|
|
|
@ -166,8 +166,8 @@ TEST(HybridBayesNet, Tiny) {
|
|||
|
||||
// prune
|
||||
auto pruned = bayesNet.prune(1);
|
||||
CHECK(pruned.at(1)->asHybrid());
|
||||
EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents());
|
||||
CHECK(pruned.at(0)->asHybrid());
|
||||
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
|
||||
EXPECT(!pruned.equals(bayesNet));
|
||||
|
||||
// error
|
||||
|
@ -437,20 +437,21 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
|
|||
const double pruneDeadVariables = 0.99;
|
||||
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
|
||||
|
||||
// First conditional is still the same: P( x0 | x1 m0)
|
||||
EXPECT(prunedBayesNet.at(0)->isHybrid());
|
||||
|
||||
// Check that hybrid conditional that only depend on M1
|
||||
// is now Gaussian and not Hybrid
|
||||
EXPECT(prunedBayesNet.at(1)->isContinuous());
|
||||
|
||||
// Third conditional is still Hybrid: P( x1 | m0 m1) -> P( x1 | m0)
|
||||
EXPECT(prunedBayesNet.at(0)->isHybrid());
|
||||
|
||||
// Check that discrete joint only has M0 and not (M0, M1)
|
||||
// since M0 is removed
|
||||
KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys();
|
||||
EXPECT(KeyVector{M(0)} == actual_keys);
|
||||
|
||||
// Check that hybrid conditionals that only depend on M1
|
||||
// are now Gaussian and not Hybrid
|
||||
EXPECT(prunedBayesNet.at(0)->isDiscrete());
|
||||
EXPECT(prunedBayesNet.at(1)->isDiscrete());
|
||||
EXPECT(prunedBayesNet.at(2)->isHybrid());
|
||||
// Only P(X2 | X1, M1) depends on M1,
|
||||
// so it gets convert to a Gaussian P(X2 | X1)
|
||||
EXPECT(prunedBayesNet.at(3)->isContinuous());
|
||||
EXPECT(prunedBayesNet.at(4)->isHybrid());
|
||||
auto joint = prunedBayesNet.at(3)->asDiscrete();
|
||||
EXPECT(joint);
|
||||
EXPECT(joint->keys() == KeyVector{M(0)});
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
@ -479,13 +480,13 @@ TEST(HybridBayesNet, ErrorAfterPruning) {
|
|||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||
double pruned_logProbability = 0;
|
||||
pruned_logProbability +=
|
||||
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
|
||||
prunedBayesNet.at(0)->asHybrid()->logProbability(hybridValues);
|
||||
pruned_logProbability +=
|
||||
prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues);
|
||||
pruned_logProbability +=
|
||||
prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues);
|
||||
pruned_logProbability +=
|
||||
prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues);
|
||||
prunedBayesNet.at(3)->asDiscrete()->logProbability(hybridValues);
|
||||
|
||||
double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values);
|
||||
|
||||
|
@ -548,8 +549,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
};
|
||||
|
||||
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
|
||||
CHECK(pruned.at(0)->asDiscrete());
|
||||
auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete();
|
||||
CHECK(pruned.at(4)->asDiscrete());
|
||||
auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete();
|
||||
auto discrete_conditional_tree =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
|
||||
pruned_discrete_conditionals);
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
@ -238,22 +239,27 @@ TEST(HybridGaussianConditional, Likelihood2) {
|
|||
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
namespace two_mode_measurement {
|
||||
// Create a two key conditional:
|
||||
const DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
|
||||
const std::vector<GaussianConditional::shared_ptr> gcs = {
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1),
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2),
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3),
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)};
|
||||
const HybridGaussianConditional::Conditionals conditionals(modes, gcs);
|
||||
const auto hgc =
|
||||
std::make_shared<HybridGaussianConditional>(modes, conditionals);
|
||||
} // namespace two_mode_measurement
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test pruning a HybridGaussianConditional with two discrete keys, based on a
|
||||
// DecisionTreeFactor with 3 keys:
|
||||
TEST(HybridGaussianConditional, Prune) {
|
||||
// Create a two key conditional:
|
||||
DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
|
||||
std::vector<GaussianConditional::shared_ptr> gcs;
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
gcs.push_back(
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
|
||||
}
|
||||
auto empty = std::make_shared<GaussianConditional>();
|
||||
HybridGaussianConditional::Conditionals conditionals(modes, gcs);
|
||||
HybridGaussianConditional hgc(modes, conditionals);
|
||||
using two_mode_measurement::hgc;
|
||||
|
||||
DiscreteKeys keys = modes;
|
||||
DiscreteKeys keys = two_mode_measurement::modes;
|
||||
keys.push_back({M(3), 2});
|
||||
{
|
||||
for (size_t i = 0; i < 8; i++) {
|
||||
|
@ -262,7 +268,7 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
// Prune the HybridGaussianConditional
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
// Check that the pruned HybridGaussianConditional has 1 conditional
|
||||
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
||||
}
|
||||
|
@ -273,14 +279,14 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
||||
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
||||
|
||||
// Check that the minimum negLogConstant is set correctly
|
||||
EXPECT_DOUBLES_EQUAL(
|
||||
hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
|
||||
hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
|
||||
pruned->negLogConstant(), 1e-9);
|
||||
}
|
||||
{
|
||||
|
@ -289,18 +295,48 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
||||
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
||||
|
||||
// Check that the minimum negLogConstant is correct
|
||||
EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9);
|
||||
}
|
||||
}
|
||||
|
||||
/* *************************************************************************
|
||||
* This test verifies the behavior of the restrict method in different
|
||||
* scenarios:
|
||||
* - When no restrictions are applied.
|
||||
* - When one parent is restricted.
|
||||
* - When two parents are restricted.
|
||||
* - When the restriction results in a Gaussian conditional.
|
||||
*/
|
||||
TEST(HybridGaussianConditional, Restrict) {
|
||||
// Create a HybridConditional with two discrete parents P(z0|m0,m1)
|
||||
const auto hc =
|
||||
std::make_shared<HybridConditional>(two_mode_measurement::hgc);
|
||||
|
||||
const HybridConditional::shared_ptr same = hc->restrict({});
|
||||
EXPECT(same->isHybrid());
|
||||
EXPECT(same->asHybrid()->nrComponents() == 4);
|
||||
|
||||
const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}});
|
||||
EXPECT(oneParent->isHybrid());
|
||||
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
|
||||
|
||||
const HybridConditional::shared_ptr oneParent2 =
|
||||
hc->restrict({{M(7), 0}, {M(1), 0}});
|
||||
EXPECT(oneParent2->isHybrid());
|
||||
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
|
||||
|
||||
const HybridConditional::shared_ptr gaussian =
|
||||
hc->restrict({{M(1), 0}, {M(2), 1}});
|
||||
EXPECT(gaussian->asGaussian());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
|
|
|
@ -103,7 +103,7 @@ TEST(HybridSmoother, IncrementalSmoother) {
|
|||
}
|
||||
|
||||
EXPECT_LONGS_EQUAL(11,
|
||||
smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues());
|
||||
smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues());
|
||||
|
||||
// Get the continuous delta update as well as
|
||||
// the optimal discrete assignment.
|
||||
|
@ -157,7 +157,7 @@ TEST(HybridSmoother, ValidPruningError) {
|
|||
}
|
||||
|
||||
EXPECT_LONGS_EQUAL(14,
|
||||
smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues());
|
||||
smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues());
|
||||
|
||||
// Get the continuous delta update as well as
|
||||
// the optimal discrete assignment.
|
||||
|
|
Loading…
Reference in New Issue