moved sequential elimination code to HybridEliminationTree
parent
3eaf4cc910
commit
cd3cfa0faa
|
|
@ -27,12 +27,20 @@ template class EliminationTree<HybridBayesNet, HybridGaussianFactorGraph>;
|
|||
HybridEliminationTree::HybridEliminationTree(
|
||||
const HybridGaussianFactorGraph& factorGraph,
|
||||
const VariableIndex& structure, const Ordering& order)
|
||||
: Base(factorGraph, structure, order) {}
|
||||
: Base(factorGraph, structure, order),
|
||||
graph_(factorGraph),
|
||||
variable_index_(structure) {
|
||||
// Segregate the continuous and the discrete keys
|
||||
std::tie(continuous_ordering_, discrete_ordering_) =
|
||||
graph_.separateContinuousDiscreteOrdering(order);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridEliminationTree::HybridEliminationTree(
|
||||
const HybridGaussianFactorGraph& factorGraph, const Ordering& order)
|
||||
: Base(factorGraph, order) {}
|
||||
: Base(factorGraph, order),
|
||||
graph_(factorGraph),
|
||||
variable_index_(VariableIndex(factorGraph)) {}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool HybridEliminationTree::equals(const This& other, double tol) const {
|
||||
|
|
|
|||
|
|
@ -33,6 +33,12 @@ class GTSAM_EXPORT HybridEliminationTree
|
|||
private:
|
||||
friend class ::EliminationTreeTester;
|
||||
|
||||
Ordering continuous_ordering_, discrete_ordering_;
|
||||
/// Used to store the original factor graph to eliminate
|
||||
HybridGaussianFactorGraph graph_;
|
||||
/// Store the provided variable index.
|
||||
VariableIndex variable_index_;
|
||||
|
||||
public:
|
||||
typedef EliminationTree<HybridBayesNet, HybridGaussianFactorGraph>
|
||||
Base; ///< Base class
|
||||
|
|
@ -66,6 +72,107 @@ class GTSAM_EXPORT HybridEliminationTree
|
|||
|
||||
/** Test whether the tree is equal to another */
|
||||
bool equals(const This& other, double tol = 1e-9) const;
|
||||
|
||||
/**
|
||||
* @brief Helper method to eliminate continuous variables.
|
||||
*
|
||||
* If no continuous variables exist, return an empty bayes net
|
||||
* and the original graph.
|
||||
*
|
||||
* @param function Elimination function for hybrid elimination.
|
||||
* @return std::pair<boost::shared_ptr<BayesNetType>,
|
||||
* boost::shared_ptr<FactorGraphType> >
|
||||
*/
|
||||
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
|
||||
eliminateContinuous(Eliminate function) const {
|
||||
if (continuous_ordering_.size() > 0) {
|
||||
This continuous_etree(graph_, variable_index_, continuous_ordering_);
|
||||
return continuous_etree.Base::eliminate(function);
|
||||
|
||||
} else {
|
||||
BayesNetType::shared_ptr bayesNet = boost::make_shared<BayesNetType>();
|
||||
FactorGraphType::shared_ptr discreteGraph =
|
||||
boost::make_shared<FactorGraphType>(graph_);
|
||||
return std::make_pair(bayesNet, discreteGraph);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper method to eliminate the discrete variables after the
|
||||
* continuous variables have been eliminated.
|
||||
*
|
||||
* If there are no discrete variables, return an empty bayes net and the
|
||||
* discreteGraph which is passed in.
|
||||
*
|
||||
* @param function Elimination function
|
||||
* @param discreteGraph The factor graph with the factor ϕ(X | M, Z).
|
||||
* @return std::pair<boost::shared_ptr<BayesNetType>,
|
||||
* boost::shared_ptr<FactorGraphType> >
|
||||
*/
|
||||
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
|
||||
eliminateDiscrete(Eliminate function,
|
||||
const FactorGraphType::shared_ptr& discreteGraph) const {
|
||||
BayesNetType::shared_ptr discreteBayesNet;
|
||||
FactorGraphType::shared_ptr finalGraph;
|
||||
if (discrete_ordering_.size() > 0) {
|
||||
This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph),
|
||||
discrete_ordering_);
|
||||
|
||||
std::tie(discreteBayesNet, finalGraph) =
|
||||
discrete_etree.Base::eliminate(function);
|
||||
|
||||
} else {
|
||||
discreteBayesNet = boost::make_shared<BayesNetType>();
|
||||
finalGraph = discreteGraph;
|
||||
}
|
||||
|
||||
return std::make_pair(discreteBayesNet, finalGraph);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Override the EliminationTree eliminate method
|
||||
* so we can perform hybrid elimination correctly.
|
||||
*
|
||||
* @param function
|
||||
* @return std::pair<boost::shared_ptr<BayesNetType>,
|
||||
* boost::shared_ptr<FactorGraphType> >
|
||||
*/
|
||||
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
|
||||
eliminate(Eliminate function) const {
|
||||
// Perform continuous elimination
|
||||
BayesNetType::shared_ptr bayesNet;
|
||||
FactorGraphType::shared_ptr discreteGraph;
|
||||
std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function);
|
||||
|
||||
// If we have eliminated continuous variables
|
||||
// and have discrete variables to eliminate,
|
||||
// then compute P(X | M, Z)
|
||||
if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) {
|
||||
// Get the last continuous conditional
|
||||
// which will have all the discrete keys
|
||||
HybridConditional::shared_ptr last_conditional =
|
||||
bayesNet->at(bayesNet->size() - 1);
|
||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||
|
||||
// DecisionTree for P'(X|M, Z) for all mode sequences M
|
||||
const AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
graph_.continuousProbPrimes(discrete_keys, bayesNet);
|
||||
|
||||
// Add the model selection factor P(M|Z)
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
}
|
||||
|
||||
// Perform discrete elimination
|
||||
BayesNetType::shared_ptr discreteBayesNet;
|
||||
FactorGraphType::shared_ptr finalGraph;
|
||||
std::tie(discreteBayesNet, finalGraph) =
|
||||
eliminateDiscrete(function, discreteGraph);
|
||||
|
||||
// Add the discrete conditionals to the hybrid conditionals
|
||||
bayesNet->add(*discreteBayesNet);
|
||||
|
||||
return std::make_pair(bayesNet, finalGraph);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
|
|||
/* ************************************************************************ */
|
||||
void HybridFactor::print(const std::string &s,
|
||||
const KeyFormatter &formatter) const {
|
||||
std::cout << (s.empty() ? "" : s + "\n");
|
||||
std::cout << s;
|
||||
if (isContinuous_) std::cout << "Continuous ";
|
||||
if (isDiscrete_) std::cout << "Discrete ";
|
||||
if (isHybrid_) std::cout << "Hybrid ";
|
||||
|
|
|
|||
|
|
@ -550,74 +550,6 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(
|
|||
return std::make_pair(continuous_ordering, discrete_ordering);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
||||
HybridGaussianFactorGraph::eliminateHybridSequential(
|
||||
const boost::optional<Ordering> continuous,
|
||||
const boost::optional<Ordering> discrete, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
const Ordering continuous_ordering =
|
||||
continuous ? *continuous : Ordering(this->continuousKeys());
|
||||
const Ordering discrete_ordering =
|
||||
discrete ? *discrete : Ordering(this->discreteKeys());
|
||||
|
||||
// Eliminate continuous
|
||||
HybridBayesNet::shared_ptr bayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||
std::tie(bayesNet, discreteGraph) =
|
||||
BaseEliminateable::eliminatePartialSequential(continuous_ordering,
|
||||
function, variableIndex);
|
||||
|
||||
// Get the last continuous conditional which will have all the discrete keys
|
||||
HybridConditional::shared_ptr last_conditional =
|
||||
bayesNet->at(bayesNet->size() - 1);
|
||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||
|
||||
// If no discrete variables, return the eliminated bayes net.
|
||||
if (discrete_keys.size() == 0) {
|
||||
return bayesNet;
|
||||
}
|
||||
|
||||
// DecisionTree for P'(X|M, Z) for all mode sequences M
|
||||
const AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
this->continuousProbPrimes(discrete_keys, bayesNet);
|
||||
|
||||
// Add the model selection factor P(M|Z)
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
|
||||
// Perform discrete elimination
|
||||
HybridBayesNet::shared_ptr discreteBayesNet =
|
||||
discreteGraph->BaseEliminateable::eliminateSequential(
|
||||
discrete_ordering, function, variableIndex);
|
||||
|
||||
bayesNet->add(*discreteBayesNet);
|
||||
|
||||
return bayesNet;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
||||
HybridGaussianFactorGraph::eliminateSequential(
|
||||
OptionalOrderingType orderingType, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
return BaseEliminateable::eliminateSequential(orderingType, function,
|
||||
variableIndex);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
||||
HybridGaussianFactorGraph::eliminateSequential(
|
||||
const Ordering &ordering, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
// Segregate the continuous and the discrete keys
|
||||
Ordering continuous_ordering, discrete_ordering;
|
||||
std::tie(continuous_ordering, discrete_ordering) =
|
||||
this->separateContinuousDiscreteOrdering(ordering);
|
||||
|
||||
return this->eliminateHybridSequential(continuous_ordering, discrete_ordering,
|
||||
function, variableIndex);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
|
||||
HybridGaussianFactorGraph::eliminateHybridMultifrontal(
|
||||
|
|
|
|||
|
|
@ -302,32 +302,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering(
|
||||
const Ordering& ordering) const;
|
||||
|
||||
/**
|
||||
* @brief Custom elimination function which computes the correct
|
||||
* continuous probabilities. Returns a bayes net.
|
||||
*
|
||||
* @param continuous Optional ordering for all continuous variables.
|
||||
* @param discrete Optional ordering for all discrete variables.
|
||||
* @return boost::shared_ptr<BayesNetType>
|
||||
*/
|
||||
boost::shared_ptr<BayesNetType> eliminateHybridSequential(
|
||||
const boost::optional<Ordering> continuous = boost::none,
|
||||
const boost::optional<Ordering> discrete = boost::none,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/// Sequential elimination overload for hybrid
|
||||
boost::shared_ptr<BayesNetType> eliminateSequential(
|
||||
OptionalOrderingType orderingType = boost::none,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/// Sequential elimination overload for hybrid
|
||||
boost::shared_ptr<BayesNetType> eliminateSequential(
|
||||
const Ordering& ordering,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Custom elimination function which computes the correct
|
||||
* continuous probabilities. Returns a bayes tree.
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
|||
addConditionals(graph, hybridBayesNet_, ordering);
|
||||
|
||||
// Eliminate.
|
||||
auto bayesNetFragment = graph.eliminateHybridSequential();
|
||||
auto bayesNetFragment = graph.eliminateSequential();
|
||||
|
||||
/// Prune
|
||||
if (maxNrLeaves) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue