Merge pull request #2135 from borglab/hybrid-improvements
commit bbd0ef5a47

@@ -71,16 +71,14 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
 /* ************************************************************************* */
 // The implementation is: build the entire joint into one factor and then prune.
-// TODO(Frank): This can be quite expensive *unless* the factors have already
+// NOTE(Frank): This can be quite expensive *unless* the factors have already
 // been pruned before. Another, possibly faster approach is branch and bound
 // search to find the K-best leaves and then create a single pruned conditional.
 DiscreteBayesNet DiscreteBayesNet::prune(
     size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
     DiscreteValues* fixedValues) const {
   // Multiply into one big conditional. NOTE: possibly quite expensive.
-  DiscreteConditional joint;
-  for (const DiscreteConditional::shared_ptr& conditional : *this)
-    joint = joint * (*conditional);
+  DiscreteConditional joint = this->joint();

   // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
   DiscreteConditional pruned = joint;
@@ -122,6 +120,15 @@ DiscreteBayesNet DiscreteBayesNet::prune(
   return result;
 }

+/* *********************************************************************** */
+DiscreteConditional DiscreteBayesNet::joint() const {
+  DiscreteConditional joint;
+  for (const DiscreteConditional::shared_ptr& conditional : *this)
+    joint = joint * (*conditional);
+
+  return joint;
+}
+
 /* *********************************************************************** */
 std::string DiscreteBayesNet::markdown(
     const KeyFormatter& keyFormatter,
@@ -136,6 +136,16 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
       const std::optional<double>& marginalThreshold = {},
       DiscreteValues* fixedValues = nullptr) const;

+  /**
+   * @brief Multiply all conditionals into one big joint conditional
+   * and return it.
+   *
+   * NOTE: possibly quite expensive.
+   *
+   * @return DiscreteConditional
+   */
+  DiscreteConditional joint() const;
+
   ///@}
   /// @name Wrapper support
   /// @{
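For reference, a minimal usage sketch of the new joint() accessor and the prune() overload declared above. This is an illustration only, assuming the Signature DSL (%, |, =) from gtsam/discrete; the keys, CPT strings, and the 0.05 threshold are made up.

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteValues.h>

using namespace gtsam;

int main() {
  // Two binary variables with made-up CPTs.
  DiscreteKey A(0, 2), B(1, 2);
  DiscreteBayesNet bn;
  bn.add(A % "60/40");            // P(A)
  bn.add(B | A = "70/30 20/80");  // P(B | A)

  // New: collapse all conditionals into one joint conditional P(A, B).
  DiscreteConditional joint = bn.joint();
  joint.print("joint");

  // Prune to at most 2 leaves; an (illustrative) marginal threshold of 0.05
  // lets prune() fix low-probability values and report them in `fixed`.
  DiscreteValues fixed;
  DiscreteBayesNet pruned = bn.prune(2, 0.05, &fixed);
  pruned.print("pruned");
  return 0;
}

As in the implementation above, joint() simply multiplies the conditionals in order, so its cost grows with the product of the cardinalities involved.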
@@ -108,7 +108,9 @@ static Eigen::SparseVector<double> ComputeSparseTable(
    *
    */
   auto op = [&](const Assignment<Key>& assignment, double p) {
-    if (p > 0) {
+    // Check if greater than 1e-11 because we consider
+    // smaller than that as numerically 0
+    if (p > 1e-11) {
       // Get all the keys involved in this assignment
       KeySet assignmentKeys;
       for (auto&& [k, _] : assignment) {
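A standalone sketch (plain C++, no GTSAM types) of why the stricter check matters: products of many probabilities are often positive yet negligible, and the new 1e-11 cutoff keeps such entries out of the sparse table.

#include <cstdio>

int main() {
  // A product of ten factors of 0.05 is about 9.8e-14: positive, so the old
  // `p > 0` test would keep it, but below 1e-11, so the new test drops it.
  double p = 1.0;
  for (int i = 0; i < 10; ++i) p *= 0.05;
  std::printf("p = %g  (p > 0): %d  (p > 1e-11): %d\n", p, p > 0, p > 1e-11);
  return 0;
}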
@@ -53,11 +53,11 @@ HybridBayesNet HybridBayesNet::prune(

   // Prune discrete Bayes net
   DiscreteValues fixed;
-  auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
+  DiscreteBayesNet prunedBN =
+      marginal.prune(maxNrLeaves, marginalThreshold, &fixed);

   // Multiply into one big conditional. NOTE: possibly quite expensive.
-  DiscreteConditional pruned;
-  for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
+  DiscreteConditional pruned = prunedBN.joint();

   // Set the fixed values if requested.
   if (marginalThreshold && fixedValues) {
@@ -86,13 +86,28 @@ Ordering HybridSmoother::maybeComputeOrdering(
 }

 /* ************************************************************************* */
-void HybridSmoother::removeFixedValues(
+HybridGaussianFactorGraph HybridSmoother::removeFixedValues(
+    const HybridGaussianFactorGraph &graph,
     const HybridGaussianFactorGraph &newFactors) {
-  for (Key key : newFactors.discreteKeySet()) {
+  // Initialize graph
+  HybridGaussianFactorGraph updatedGraph(graph);
+
+  for (DiscreteKey dkey : newFactors.discreteKeys()) {
+    Key key = dkey.first;
     if (fixedValues_.find(key) != fixedValues_.end()) {
+      // Add corresponding discrete factor to reintroduce the information
+      std::vector<double> probabilities(
+          dkey.second, (1 - *marginalThreshold_) / dkey.second);
+      probabilities[fixedValues_[key]] = *marginalThreshold_;
+      DecisionTreeFactor dtf({dkey}, probabilities);
+      updatedGraph.push_back(dtf);
+
+      // Remove fixed value
       fixedValues_.erase(key);
     }
   }
+
+  return updatedGraph;
 }

 /* ************************************************************************* */
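A worked example of the factor that removeFixedValues now reintroduces, with made-up numbers: for a binary mode key previously fixed to value 1 and marginalThreshold_ = 0.99, every entry starts at (1 - 0.99)/2 = 0.005 and the fixed value is bumped to 0.99, giving the table {0.005, 0.99}. A self-contained sketch of the same construction:

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>

#include <vector>

using namespace gtsam;

int main() {
  DiscreteKey mode(0, 2);   // binary discrete key (illustrative)
  double threshold = 0.99;  // stands in for *marginalThreshold_
  size_t fixedValue = 1;    // stands in for fixedValues_[key]

  // Same construction as in removeFixedValues above.
  std::vector<double> probabilities(mode.second, (1 - threshold) / mode.second);
  probabilities[fixedValue] = threshold;  // -> {0.005, 0.99}

  DecisionTreeFactor reintroduced({mode}, probabilities);
  reintroduced.print("reintroduced");
  return 0;
}

The table need not be normalized; it simply re-injects a strong preference for the previously fixed assignment once that key appears in new factors again.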
@@ -126,6 +141,11 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors,
             << std::endl;
 #endif

+  if (marginalThreshold_) {
+    // Remove fixed values for discrete keys which are introduced in newFactors
+    updatedGraph = removeFixedValues(updatedGraph, newFactors);
+  }
+
   Ordering ordering = this->maybeComputeOrdering(updatedGraph, given_ordering);

 #if GTSAM_HYBRID_TIMING
@@ -145,9 +165,6 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors,
   }
 #endif

-  // Remove fixed values for discrete keys which are introduced in newFactors
-  removeFixedValues(newFactors);
-
 #ifdef DEBUG_SMOOTHER
   // Print discrete keys in the bayesNetFragment:
   std::cout << "Discrete keys in bayesNetFragment: ";
@@ -145,8 +145,19 @@ class GTSAM_EXPORT HybridSmoother {
   Ordering maybeComputeOrdering(const HybridGaussianFactorGraph& updatedGraph,
                                 const std::optional<Ordering> givenOrdering);

-  /// Remove fixed discrete values for discrete keys introduced in `newFactors`.
-  void removeFixedValues(const HybridGaussianFactorGraph& newFactors);
+  /**
+   * @brief Remove fixed discrete values for discrete keys
+   * introduced in `newFactors`, and reintroduce discrete factors
+   * with marginalThreshold_ as the probability value.
+   *
+   * @param graph The factor graph with previous conditionals added in.
+   * @param newFactors The new factors added to the smoother,
+   * used to check if a fixed discrete value has been reintroduced.
+   * @return HybridGaussianFactorGraph
+   */
+  HybridGaussianFactorGraph removeFixedValues(
+      const HybridGaussianFactorGraph& graph,
+      const HybridGaussianFactorGraph& newFactors);
 };

 } // namespace gtsam
@@ -152,6 +152,7 @@ class HybridBayesNet {
   gtsam::HybridGaussianFactorGraph toFactorGraph(
       const gtsam::VectorValues& measurements) const;

+  gtsam::DiscreteBayesNet discreteMarginal() const;
   gtsam::GaussianBayesNet choose(const gtsam::DiscreteValues& assignment) const;

   gtsam::HybridValues optimize() const;
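The newly wrapped discreteMarginal() complements choose(): one exposes the discrete part of the hybrid posterior, the other the Gaussian Bayes net for a particular discrete assignment. A hedged C++ sketch; the helper name and the use of print() for inspection are illustrative, not part of the PR.

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/linear/GaussianBayesNet.h>

using namespace gtsam;

// Hypothetical helper: inspect the discrete marginal of a hybrid posterior,
// then extract the Gaussian Bayes net for one chosen discrete assignment.
GaussianBayesNet inspectAndChoose(const HybridBayesNet& posterior,
                                  const DiscreteValues& assignment) {
  DiscreteBayesNet marginal = posterior.discreteMarginal();
  marginal.print("discrete marginal");  // P(M) over the discrete modes
  return posterior.choose(assignment);  // GBN conditioned on this assignment
}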