overload multifrontal elimination
parent
eb94ad90d2
commit
0938159706
|
|
@ -546,6 +546,24 @@ HybridGaussianFactorGraph::continuousDelta(
|
|||
return delta_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTree<Key, VectorValues::shared_ptr>
|
||||
HybridGaussianFactorGraph::continuousDelta(
|
||||
const DiscreteKeys &discrete_keys,
|
||||
const boost::shared_ptr<BayesTreeType> &continuousBayesTree,
|
||||
const std::vector<DiscreteValues> &assignments) const {
|
||||
// Create a decision tree of all the different VectorValues
|
||||
std::vector<VectorValues::shared_ptr> vector_values;
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
VectorValues values = continuousBayesTree->optimize(assignment);
|
||||
vector_values.push_back(boost::make_shared<VectorValues>(values));
|
||||
}
|
||||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
|
||||
vector_values);
|
||||
|
||||
return delta_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
|
||||
const DiscreteKeys &orig_discrete_keys,
|
||||
|
|
@ -584,6 +602,67 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
|
|||
return probPrimeTree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
|
||||
const DiscreteKeys &orig_discrete_keys,
|
||||
const boost::shared_ptr<BayesTreeType> &continuousBayesTree) const {
|
||||
// Generate all possible assignments.
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(orig_discrete_keys);
|
||||
|
||||
// Save a copy of the original discrete key ordering
|
||||
DiscreteKeys discrete_keys(orig_discrete_keys);
|
||||
// Reverse discrete keys order for correct tree construction
|
||||
std::reverse(discrete_keys.begin(), discrete_keys.end());
|
||||
|
||||
// Create a decision tree of all the different VectorValues
|
||||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
|
||||
this->continuousDelta(discrete_keys, continuousBayesTree, assignments);
|
||||
|
||||
// Get the probPrime tree with the correct leaf probabilities
|
||||
std::vector<double> probPrimes;
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
VectorValues delta = *delta_tree(assignment);
|
||||
|
||||
// If VectorValues is empty, it means this is a pruned branch.
|
||||
// Set thr probPrime to 0.0.
|
||||
if (delta.size() == 0) {
|
||||
probPrimes.push_back(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute the error given the delta and the assignment.
|
||||
double error = this->error(delta, assignment);
|
||||
probPrimes.push_back(exp(-error));
|
||||
}
|
||||
|
||||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
||||
return probPrimeTree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<Ordering, Ordering>
|
||||
HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(
|
||||
const Ordering &ordering) const {
|
||||
KeySet all_continuous_keys = this->continuousKeys();
|
||||
KeySet all_discrete_keys = this->discreteKeys();
|
||||
Ordering continuous_ordering, discrete_ordering;
|
||||
|
||||
for (auto &&key : ordering) {
|
||||
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(),
|
||||
key) != all_continuous_keys.end()) {
|
||||
continuous_ordering.push_back(key);
|
||||
} else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(),
|
||||
key) != all_discrete_keys.end()) {
|
||||
discrete_ordering.push_back(key);
|
||||
} else {
|
||||
throw std::runtime_error("Key in ordering not present in factors.");
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(continuous_ordering, discrete_ordering);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
||||
HybridGaussianFactorGraph::eliminateHybridSequential(
|
||||
|
|
@ -640,25 +719,96 @@ boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
|||
HybridGaussianFactorGraph::eliminateSequential(
|
||||
const Ordering &ordering, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
KeySet all_continuous_keys = this->continuousKeys();
|
||||
KeySet all_discrete_keys = this->discreteKeys();
|
||||
Ordering continuous_ordering, discrete_ordering;
|
||||
|
||||
// Segregate the continuous and the discrete keys
|
||||
for (auto &&key : ordering) {
|
||||
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(),
|
||||
key) != all_continuous_keys.end()) {
|
||||
continuous_ordering.push_back(key);
|
||||
} else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(),
|
||||
key) != all_discrete_keys.end()) {
|
||||
discrete_ordering.push_back(key);
|
||||
} else {
|
||||
throw std::runtime_error("Key in ordering not present in factors.");
|
||||
}
|
||||
Ordering continuous_ordering, discrete_ordering;
|
||||
std::tie(continuous_ordering, discrete_ordering) =
|
||||
this->separateContinuousDiscreteOrdering(ordering);
|
||||
|
||||
return this->eliminateHybridSequential(continuous_ordering, discrete_ordering,
|
||||
function, variableIndex);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
|
||||
HybridGaussianFactorGraph::eliminateHybridMultifrontal(
|
||||
const boost::optional<Ordering> continuous,
|
||||
const boost::optional<Ordering> discrete, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
Ordering continuous_ordering =
|
||||
continuous ? *continuous : Ordering(this->continuousKeys());
|
||||
Ordering discrete_ordering =
|
||||
discrete ? *discrete : Ordering(this->discreteKeys());
|
||||
|
||||
// Eliminate continuous
|
||||
HybridBayesTree::shared_ptr bayesTree;
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||
std::tie(bayesTree, discreteGraph) =
|
||||
BaseEliminateable::eliminatePartialMultifrontal(continuous_ordering,
|
||||
function, variableIndex);
|
||||
|
||||
// Get the last continuous conditional which will have all the discrete
|
||||
Key last_continuous_key =
|
||||
continuous_ordering.at(continuous_ordering.size() - 1);
|
||||
auto last_conditional = (*bayesTree)[last_continuous_key]->conditional();
|
||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||
|
||||
// If not discrete variables, return the eliminated bayes net.
|
||||
if (discrete_keys.size() == 0) {
|
||||
return bayesTree;
|
||||
}
|
||||
|
||||
return this->eliminateHybridSequential(continuous_ordering,
|
||||
discrete_ordering);
|
||||
AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
this->continuousProbPrimes(discrete_keys, bayesTree);
|
||||
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
|
||||
auto updatedBayesTree =
|
||||
discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete_ordering,
|
||||
function);
|
||||
|
||||
auto discrete_clique = (*updatedBayesTree)[discrete_ordering.at(0)];
|
||||
|
||||
// Set the root of the bayes tree as the discrete clique
|
||||
for (auto node : bayesTree->nodes()) {
|
||||
auto clique = node.second;
|
||||
|
||||
if (clique->conditional()->parents() ==
|
||||
discrete_clique->conditional()->frontals()) {
|
||||
updatedBayesTree->addClique(clique, discrete_clique);
|
||||
|
||||
} else {
|
||||
// Remove the clique from the children of the parents since it will get
|
||||
// added again in addClique.
|
||||
auto clique_it = std::find(clique->parent()->children.begin(),
|
||||
clique->parent()->children.end(), clique);
|
||||
clique->parent()->children.erase(clique_it);
|
||||
updatedBayesTree->addClique(clique, clique->parent());
|
||||
}
|
||||
}
|
||||
return updatedBayesTree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
|
||||
HybridGaussianFactorGraph::eliminateMultifrontal(
|
||||
OptionalOrderingType orderingType, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
return BaseEliminateable::eliminateMultifrontal(orderingType, function,
|
||||
variableIndex);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
|
||||
HybridGaussianFactorGraph::eliminateMultifrontal(
|
||||
const Ordering &ordering, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
// Segregate the continuous and the discrete keys
|
||||
Ordering continuous_ordering, discrete_ordering;
|
||||
std::tie(continuous_ordering, discrete_ordering) =
|
||||
this->separateContinuousDiscreteOrdering(ordering);
|
||||
|
||||
return this->eliminateHybridMultifrontal(
|
||||
continuous_ordering, discrete_ordering, function, variableIndex);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -231,6 +231,10 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
const DiscreteKeys& discrete_keys,
|
||||
const boost::shared_ptr<BayesNetType>& continuousBayesNet,
|
||||
const std::vector<DiscreteValues>& assignments) const;
|
||||
DecisionTree<Key, VectorValues::shared_ptr> continuousDelta(
|
||||
const DiscreteKeys& discrete_keys,
|
||||
const boost::shared_ptr<BayesTreeType>& continuousBayesTree,
|
||||
const std::vector<DiscreteValues>& assignments) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the unnormalized probabilities of the continuous variables
|
||||
|
|
@ -244,6 +248,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
AlgebraicDecisionTree<Key> continuousProbPrimes(
|
||||
const DiscreteKeys& discrete_keys,
|
||||
const boost::shared_ptr<BayesNetType>& continuousBayesNet) const;
|
||||
AlgebraicDecisionTree<Key> continuousProbPrimes(
|
||||
const DiscreteKeys& discrete_keys,
|
||||
const boost::shared_ptr<BayesTreeType>& continuousBayesTree) const;
|
||||
|
||||
std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering(
|
||||
const Ordering& ordering) const;
|
||||
|
||||
/**
|
||||
* @brief Custom elimination function which computes the correct
|
||||
|
|
@ -269,6 +279,22 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
boost::shared_ptr<BayesTreeType> eliminateHybridMultifrontal(
|
||||
const boost::optional<Ordering> continuous = boost::none,
|
||||
const boost::optional<Ordering> discrete = boost::none,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
|
||||
OptionalOrderingType orderingType = boost::none,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
|
||||
const Ordering& ordering,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
* eliminated after the continuous keys.
|
||||
|
|
|
|||
Loading…
Reference in New Issue