moved sequential elimination code to HybridEliminationTree

release/4.3a0
Varun Agrawal 2022-12-03 17:14:11 +05:30
parent 3eaf4cc910
commit cd3cfa0faa
6 changed files with 119 additions and 98 deletions

View File

@ -27,12 +27,20 @@ template class EliminationTree<HybridBayesNet, HybridGaussianFactorGraph>;
HybridEliminationTree::HybridEliminationTree(
const HybridGaussianFactorGraph& factorGraph,
const VariableIndex& structure, const Ordering& order)
: Base(factorGraph, structure, order) {}
: Base(factorGraph, structure, order),
graph_(factorGraph),
variable_index_(structure) {
// Segregate the continuous and the discrete keys
std::tie(continuous_ordering_, discrete_ordering_) =
graph_.separateContinuousDiscreteOrdering(order);
}
/* ************************************************************************* */
HybridEliminationTree::HybridEliminationTree(
const HybridGaussianFactorGraph& factorGraph, const Ordering& order)
: Base(factorGraph, order) {}
: Base(factorGraph, order),
graph_(factorGraph),
variable_index_(VariableIndex(factorGraph)) {}
/* ************************************************************************* */
bool HybridEliminationTree::equals(const This& other, double tol) const {

View File

@ -33,6 +33,12 @@ class GTSAM_EXPORT HybridEliminationTree
private:
friend class ::EliminationTreeTester;
Ordering continuous_ordering_, discrete_ordering_;
/// Used to store the original factor graph to eliminate
HybridGaussianFactorGraph graph_;
/// Store the provided variable index.
VariableIndex variable_index_;
public:
typedef EliminationTree<HybridBayesNet, HybridGaussianFactorGraph>
Base; ///< Base class
@ -66,6 +72,107 @@ class GTSAM_EXPORT HybridEliminationTree
/** Test whether the tree is equal to another */
bool equals(const This& other, double tol = 1e-9) const;
/**
* @brief Helper method to eliminate continuous variables.
*
* If no continuous variables exist, return an empty bayes net
* and the original graph.
*
* @param function Elimination function for hybrid elimination.
* @return std::pair<boost::shared_ptr<BayesNetType>,
* boost::shared_ptr<FactorGraphType> >
*/
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
eliminateContinuous(Eliminate function) const {
if (continuous_ordering_.size() > 0) {
This continuous_etree(graph_, variable_index_, continuous_ordering_);
return continuous_etree.Base::eliminate(function);
} else {
BayesNetType::shared_ptr bayesNet = boost::make_shared<BayesNetType>();
FactorGraphType::shared_ptr discreteGraph =
boost::make_shared<FactorGraphType>(graph_);
return std::make_pair(bayesNet, discreteGraph);
}
}
/**
* @brief Helper method to eliminate the discrete variables after the
* continuous variables have been eliminated.
*
* If there are no discrete variables, return an empty bayes net and the
* discreteGraph which is passed in.
*
* @param function Elimination function
* @param discreteGraph The factor graph with the factor ϕ(X | M, Z).
* @return std::pair<boost::shared_ptr<BayesNetType>,
* boost::shared_ptr<FactorGraphType> >
*/
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
eliminateDiscrete(Eliminate function,
const FactorGraphType::shared_ptr& discreteGraph) const {
BayesNetType::shared_ptr discreteBayesNet;
FactorGraphType::shared_ptr finalGraph;
if (discrete_ordering_.size() > 0) {
This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph),
discrete_ordering_);
std::tie(discreteBayesNet, finalGraph) =
discrete_etree.Base::eliminate(function);
} else {
discreteBayesNet = boost::make_shared<BayesNetType>();
finalGraph = discreteGraph;
}
return std::make_pair(discreteBayesNet, finalGraph);
}
/**
* @brief Override the EliminationTree eliminate method
* so we can perform hybrid elimination correctly.
*
* @param function
* @return std::pair<boost::shared_ptr<BayesNetType>,
* boost::shared_ptr<FactorGraphType> >
*/
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
eliminate(Eliminate function) const {
// Perform continuous elimination
BayesNetType::shared_ptr bayesNet;
FactorGraphType::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function);
// If we have eliminated continuous variables
// and have discrete variables to eliminate,
// then compute P(X | M, Z)
if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) {
// Get the last continuous conditional
// which will have all the discrete keys
HybridConditional::shared_ptr last_conditional =
bayesNet->at(bayesNet->size() - 1);
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
// DecisionTree for P'(X|M, Z) for all mode sequences M
const AlgebraicDecisionTree<Key> probPrimeTree =
graph_.continuousProbPrimes(discrete_keys, bayesNet);
// Add the model selection factor P(M|Z)
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
}
// Perform discrete elimination
BayesNetType::shared_ptr discreteBayesNet;
FactorGraphType::shared_ptr finalGraph;
std::tie(discreteBayesNet, finalGraph) =
eliminateDiscrete(function, discreteGraph);
// Add the discrete conditionals to the hybrid conditionals
bayesNet->add(*discreteBayesNet);
return std::make_pair(bayesNet, finalGraph);
}
};
} // namespace gtsam

View File

@ -81,7 +81,7 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
/* ************************************************************************ */
void HybridFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << (s.empty() ? "" : s + "\n");
std::cout << s;
if (isContinuous_) std::cout << "Continuous ";
if (isDiscrete_) std::cout << "Discrete ";
if (isHybrid_) std::cout << "Hybrid ";

View File

@ -550,74 +550,6 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(
return std::make_pair(continuous_ordering, discrete_ordering);
}
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateHybridSequential(
const boost::optional<Ordering> continuous,
const boost::optional<Ordering> discrete, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
const Ordering continuous_ordering =
continuous ? *continuous : Ordering(this->continuousKeys());
const Ordering discrete_ordering =
discrete ? *discrete : Ordering(this->discreteKeys());
// Eliminate continuous
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
BaseEliminateable::eliminatePartialSequential(continuous_ordering,
function, variableIndex);
// Get the last continuous conditional which will have all the discrete keys
HybridConditional::shared_ptr last_conditional =
bayesNet->at(bayesNet->size() - 1);
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
// If no discrete variables, return the eliminated bayes net.
if (discrete_keys.size() == 0) {
return bayesNet;
}
// DecisionTree for P'(X|M, Z) for all mode sequences M
const AlgebraicDecisionTree<Key> probPrimeTree =
this->continuousProbPrimes(discrete_keys, bayesNet);
// Add the model selection factor P(M|Z)
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
// Perform discrete elimination
HybridBayesNet::shared_ptr discreteBayesNet =
discreteGraph->BaseEliminateable::eliminateSequential(
discrete_ordering, function, variableIndex);
bayesNet->add(*discreteBayesNet);
return bayesNet;
}
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateSequential(
OptionalOrderingType orderingType, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
return BaseEliminateable::eliminateSequential(orderingType, function,
variableIndex);
}
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateSequential(
const Ordering &ordering, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
// Segregate the continuous and the discrete keys
Ordering continuous_ordering, discrete_ordering;
std::tie(continuous_ordering, discrete_ordering) =
this->separateContinuousDiscreteOrdering(ordering);
return this->eliminateHybridSequential(continuous_ordering, discrete_ordering,
function, variableIndex);
}
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
HybridGaussianFactorGraph::eliminateHybridMultifrontal(

View File

@ -302,32 +302,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering(
const Ordering& ordering) const;
/**
* @brief Custom elimination function which computes the correct
* continuous probabilities. Returns a bayes net.
*
* @param continuous Optional ordering for all continuous variables.
* @param discrete Optional ordering for all discrete variables.
* @return boost::shared_ptr<BayesNetType>
*/
boost::shared_ptr<BayesNetType> eliminateHybridSequential(
const boost::optional<Ordering> continuous = boost::none,
const boost::optional<Ordering> discrete = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
/// Sequential elimination overload for hybrid
boost::shared_ptr<BayesNetType> eliminateSequential(
OptionalOrderingType orderingType = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
/// Sequential elimination overload for hybrid
boost::shared_ptr<BayesNetType> eliminateSequential(
const Ordering& ordering,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
/**
* @brief Custom elimination function which computes the correct
* continuous probabilities. Returns a bayes tree.

View File

@ -32,7 +32,7 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
addConditionals(graph, hybridBayesNet_, ordering);
// Eliminate.
auto bayesNetFragment = graph.eliminateHybridSequential();
auto bayesNetFragment = graph.eliminateSequential();
/// Prune
if (maxNrLeaves) {