Merge pull request #2135 from borglab/hybrid-improvements

release/4.3a0
Varun Agrawal 2025-05-15 22:53:57 -04:00 committed by GitHub
commit bbd0ef5a47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 63 additions and 15 deletions

View File

@ -71,16 +71,14 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
/* ************************************************************************* */ /* ************************************************************************* */
// The implementation is: build the entire joint into one factor and then prune. // The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already // NOTE(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound // been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional. // search to find the K-best leaves and then create a single pruned conditional.
DiscreteBayesNet DiscreteBayesNet::prune( DiscreteBayesNet DiscreteBayesNet::prune(
size_t maxNrLeaves, const std::optional<double>& marginalThreshold, size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
DiscreteValues* fixedValues) const { DiscreteValues* fixedValues) const {
// Multiply into one big conditional. NOTE: possibly quite expensive. // Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint; DiscreteConditional joint = this->joint();
for (const DiscreteConditional::shared_ptr& conditional : *this)
joint = joint * (*conditional);
// Prune the joint. NOTE: imperative and, again, possibly quite expensive. // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint; DiscreteConditional pruned = joint;
@ -122,6 +120,15 @@ DiscreteBayesNet DiscreteBayesNet::prune(
return result; return result;
} }
/* *********************************************************************** */
DiscreteConditional DiscreteBayesNet::joint() const {
DiscreteConditional joint;
for (const DiscreteConditional::shared_ptr& conditional : *this)
joint = joint * (*conditional);
return joint;
}
/* *********************************************************************** */ /* *********************************************************************** */
std::string DiscreteBayesNet::markdown( std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,

View File

@ -136,6 +136,16 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
const std::optional<double>& marginalThreshold = {}, const std::optional<double>& marginalThreshold = {},
DiscreteValues* fixedValues = nullptr) const; DiscreteValues* fixedValues = nullptr) const;
/**
* @brief Multiply all conditionals into one big joint conditional
* and return it.
*
* NOTE: possibly quite expensive.
*
* @return DiscreteConditional
*/
DiscreteConditional joint() const;
///@} ///@}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -108,7 +108,9 @@ static Eigen::SparseVector<double> ComputeSparseTable(
* *
*/ */
auto op = [&](const Assignment<Key>& assignment, double p) { auto op = [&](const Assignment<Key>& assignment, double p) {
if (p > 0) { // Check if greater than 1e-11 because we consider
// smaller than that as numerically 0
if (p > 1e-11) {
// Get all the keys involved in this assignment // Get all the keys involved in this assignment
KeySet assignmentKeys; KeySet assignmentKeys;
for (auto&& [k, _] : assignment) { for (auto&& [k, _] : assignment) {

View File

@ -53,11 +53,11 @@ HybridBayesNet HybridBayesNet::prune(
// Prune discrete Bayes net // Prune discrete Bayes net
DiscreteValues fixed; DiscreteValues fixed;
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed); DiscreteBayesNet prunedBN =
marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
// Multiply into one big conditional. NOTE: possibly quite expensive. // Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional pruned; DiscreteConditional pruned = prunedBN.joint();
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
// Set the fixed values if requested. // Set the fixed values if requested.
if (marginalThreshold && fixedValues) { if (marginalThreshold && fixedValues) {

View File

@ -86,13 +86,28 @@ Ordering HybridSmoother::maybeComputeOrdering(
} }
/* ************************************************************************* */ /* ************************************************************************* */
void HybridSmoother::removeFixedValues( HybridGaussianFactorGraph HybridSmoother::removeFixedValues(
const HybridGaussianFactorGraph &graph,
const HybridGaussianFactorGraph &newFactors) { const HybridGaussianFactorGraph &newFactors) {
for (Key key : newFactors.discreteKeySet()) { // Initialize graph
HybridGaussianFactorGraph updatedGraph(graph);
for (DiscreteKey dkey : newFactors.discreteKeys()) {
Key key = dkey.first;
if (fixedValues_.find(key) != fixedValues_.end()) { if (fixedValues_.find(key) != fixedValues_.end()) {
// Add corresponding discrete factor to reintroduce the information
std::vector<double> probabilities(
dkey.second, (1 - *marginalThreshold_) / dkey.second);
probabilities[fixedValues_[key]] = *marginalThreshold_;
DecisionTreeFactor dtf({dkey}, probabilities);
updatedGraph.push_back(dtf);
// Remove fixed value
fixedValues_.erase(key); fixedValues_.erase(key);
} }
} }
return updatedGraph;
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -126,6 +141,11 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors,
<< std::endl; << std::endl;
#endif #endif
if (marginalThreshold_) {
// Remove fixed values for discrete keys which are introduced in newFactors
updatedGraph = removeFixedValues(updatedGraph, newFactors);
}
Ordering ordering = this->maybeComputeOrdering(updatedGraph, given_ordering); Ordering ordering = this->maybeComputeOrdering(updatedGraph, given_ordering);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
@ -145,9 +165,6 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors,
} }
#endif #endif
// Remove fixed values for discrete keys which are introduced in newFactors
removeFixedValues(newFactors);
#ifdef DEBUG_SMOOTHER #ifdef DEBUG_SMOOTHER
// Print discrete keys in the bayesNetFragment: // Print discrete keys in the bayesNetFragment:
std::cout << "Discrete keys in bayesNetFragment: "; std::cout << "Discrete keys in bayesNetFragment: ";

View File

@ -145,8 +145,19 @@ class GTSAM_EXPORT HybridSmoother {
Ordering maybeComputeOrdering(const HybridGaussianFactorGraph& updatedGraph, Ordering maybeComputeOrdering(const HybridGaussianFactorGraph& updatedGraph,
const std::optional<Ordering> givenOrdering); const std::optional<Ordering> givenOrdering);
/// Remove fixed discrete values for discrete keys introduced in `newFactors`. /**
void removeFixedValues(const HybridGaussianFactorGraph& newFactors); * @brief Remove fixed discrete values for discrete keys
* introduced in `newFactors`, and reintroduce discrete factors
* with marginalThreshold_ as the probability value.
*
* @param graph The factor graph with previous conditionals added in.
* @param newFactors The new factors added to the smoother,
* used to check if a fixed discrete value has been reintroduced.
* @return HybridGaussianFactorGraph
*/
HybridGaussianFactorGraph removeFixedValues(
const HybridGaussianFactorGraph& graph,
const HybridGaussianFactorGraph& newFactors);
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -152,6 +152,7 @@ class HybridBayesNet {
gtsam::HybridGaussianFactorGraph toFactorGraph( gtsam::HybridGaussianFactorGraph toFactorGraph(
const gtsam::VectorValues& measurements) const; const gtsam::VectorValues& measurements) const;
gtsam::DiscreteBayesNet discreteMarginal() const;
gtsam::GaussianBayesNet choose(const gtsam::DiscreteValues& assignment) const; gtsam::GaussianBayesNet choose(const gtsam::DiscreteValues& assignment) const;
gtsam::HybridValues optimize() const; gtsam::HybridValues optimize() const;