split up the eliminate method to constituent parts
parent
addbe2a5a5
commit
ae0b3e3902
|
|
@ -47,4 +47,92 @@ bool HybridEliminationTree::equals(const This& other, double tol) const {
|
|||
return Base::equals(other, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
HybridEliminationTree::eliminateContinuous(Eliminate function) const {
|
||||
if (continuous_ordering_.size() > 0) {
|
||||
This continuous_etree(graph_, variable_index_, continuous_ordering_);
|
||||
return continuous_etree.Base::eliminate(function);
|
||||
|
||||
} else {
|
||||
HybridBayesNet::shared_ptr bayesNet = boost::make_shared<HybridBayesNet>();
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph =
|
||||
boost::make_shared<HybridGaussianFactorGraph>(graph_);
|
||||
return std::make_pair(bayesNet, discreteGraph);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>
|
||||
HybridEliminationTree::addProbPrimes(
|
||||
const HybridBayesNet::shared_ptr& continuousBayesNet,
|
||||
const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const {
|
||||
if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) {
|
||||
// Get the last continuous conditional
|
||||
// which will have all the discrete keys
|
||||
HybridConditional::shared_ptr last_conditional =
|
||||
continuousBayesNet->at(continuousBayesNet->size() - 1);
|
||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||
|
||||
// DecisionTree for P'(X|M, Z) for all mode sequences M
|
||||
const AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
graph_.continuousProbPrimes(discrete_keys, continuousBayesNet);
|
||||
|
||||
// Add the model selection factor P(M|Z)
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
}
|
||||
return discreteGraph;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
HybridEliminationTree::eliminateDiscrete(
|
||||
Eliminate function,
|
||||
const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const {
|
||||
HybridBayesNet::shared_ptr discreteBayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr finalGraph;
|
||||
if (discrete_ordering_.size() > 0) {
|
||||
This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph),
|
||||
discrete_ordering_);
|
||||
|
||||
std::tie(discreteBayesNet, finalGraph) =
|
||||
discrete_etree.Base::eliminate(function);
|
||||
|
||||
} else {
|
||||
discreteBayesNet = boost::make_shared<HybridBayesNet>();
|
||||
finalGraph = discreteGraph;
|
||||
}
|
||||
|
||||
return std::make_pair(discreteBayesNet, finalGraph);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
HybridEliminationTree::eliminate(Eliminate function) const {
|
||||
// Perform continuous elimination
|
||||
HybridBayesNet::shared_ptr bayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||
std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function);
|
||||
|
||||
// If we have eliminated continuous variables
|
||||
// and have discrete variables to eliminate,
|
||||
// then compute P(X | M, Z)
|
||||
HybridGaussianFactorGraph::shared_ptr updatedDiscreteGraph =
|
||||
addProbPrimes(bayesNet, discreteGraph);
|
||||
|
||||
// Perform discrete elimination
|
||||
HybridBayesNet::shared_ptr discreteBayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr finalGraph;
|
||||
std::tie(discreteBayesNet, finalGraph) =
|
||||
eliminateDiscrete(function, updatedDiscreteGraph);
|
||||
|
||||
// Add the discrete conditionals to the hybrid conditionals
|
||||
bayesNet->add(*discreteBayesNet);
|
||||
|
||||
return std::make_pair(bayesNet, finalGraph);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -80,22 +80,12 @@ class GTSAM_EXPORT HybridEliminationTree
|
|||
* and the original graph.
|
||||
*
|
||||
* @param function Elimination function for hybrid elimination.
|
||||
* @return std::pair<boost::shared_ptr<BayesNetType>,
|
||||
* boost::shared_ptr<FactorGraphType> >
|
||||
* @return std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||
* boost::shared_ptr<HybridGaussianFactorGraph> >
|
||||
*/
|
||||
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
|
||||
eliminateContinuous(Eliminate function) const {
|
||||
if (continuous_ordering_.size() > 0) {
|
||||
This continuous_etree(graph_, variable_index_, continuous_ordering_);
|
||||
return continuous_etree.Base::eliminate(function);
|
||||
|
||||
} else {
|
||||
BayesNetType::shared_ptr bayesNet = boost::make_shared<BayesNetType>();
|
||||
FactorGraphType::shared_ptr discreteGraph =
|
||||
boost::make_shared<FactorGraphType>(graph_);
|
||||
return std::make_pair(bayesNet, discreteGraph);
|
||||
}
|
||||
}
|
||||
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
eliminateContinuous(Eliminate function) const;
|
||||
|
||||
/**
|
||||
* @brief Helper method to eliminate the discrete variables after the
|
||||
|
|
@ -104,75 +94,28 @@ class GTSAM_EXPORT HybridEliminationTree
|
|||
* If there are no discrete variables, return an empty bayes net and the
|
||||
* discreteGraph which is passed in.
|
||||
*
|
||||
* @param function Elimination function
|
||||
* @param function Hybrid elimination function
|
||||
* @param discreteGraph The factor graph with the factor ϕ(X | M, Z).
|
||||
* @return std::pair<boost::shared_ptr<BayesNetType>,
|
||||
* boost::shared_ptr<FactorGraphType> >
|
||||
* @return std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||
* boost::shared_ptr<HybridGaussianFactorGraph> >
|
||||
*/
|
||||
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
|
||||
eliminateDiscrete(Eliminate function,
|
||||
const FactorGraphType::shared_ptr& discreteGraph) const {
|
||||
BayesNetType::shared_ptr discreteBayesNet;
|
||||
FactorGraphType::shared_ptr finalGraph;
|
||||
if (discrete_ordering_.size() > 0) {
|
||||
This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph),
|
||||
discrete_ordering_);
|
||||
|
||||
std::tie(discreteBayesNet, finalGraph) =
|
||||
discrete_etree.Base::eliminate(function);
|
||||
|
||||
} else {
|
||||
discreteBayesNet = boost::make_shared<BayesNetType>();
|
||||
finalGraph = discreteGraph;
|
||||
}
|
||||
|
||||
return std::make_pair(discreteBayesNet, finalGraph);
|
||||
}
|
||||
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
eliminateDiscrete(
|
||||
Eliminate function,
|
||||
const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const;
|
||||
|
||||
/**
|
||||
* @brief Override the EliminationTree eliminate method
|
||||
* so we can perform hybrid elimination correctly.
|
||||
*
|
||||
* @param function
|
||||
* @return std::pair<boost::shared_ptr<BayesNetType>,
|
||||
* boost::shared_ptr<FactorGraphType> >
|
||||
* @param function Hybrid elimination function
|
||||
* @return std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||
* boost::shared_ptr<HybridGaussianFactorGraph> >
|
||||
*/
|
||||
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
|
||||
eliminate(Eliminate function) const {
|
||||
// Perform continuous elimination
|
||||
BayesNetType::shared_ptr bayesNet;
|
||||
FactorGraphType::shared_ptr discreteGraph;
|
||||
std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function);
|
||||
|
||||
// If we have eliminated continuous variables
|
||||
// and have discrete variables to eliminate,
|
||||
// then compute P(X | M, Z)
|
||||
if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) {
|
||||
// Get the last continuous conditional
|
||||
// which will have all the discrete keys
|
||||
HybridConditional::shared_ptr last_conditional =
|
||||
bayesNet->at(bayesNet->size() - 1);
|
||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||
|
||||
// DecisionTree for P'(X|M, Z) for all mode sequences M
|
||||
const AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
graph_.continuousProbPrimes(discrete_keys, bayesNet);
|
||||
|
||||
// Add the model selection factor P(M|Z)
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
}
|
||||
|
||||
// Perform discrete elimination
|
||||
BayesNetType::shared_ptr discreteBayesNet;
|
||||
FactorGraphType::shared_ptr finalGraph;
|
||||
std::tie(discreteBayesNet, finalGraph) =
|
||||
eliminateDiscrete(function, discreteGraph);
|
||||
|
||||
// Add the discrete conditionals to the hybrid conditionals
|
||||
bayesNet->add(*discreteBayesNet);
|
||||
|
||||
return std::make_pair(bayesNet, finalGraph);
|
||||
}
|
||||
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
eliminate(Eliminate function) const;
|
||||
|
||||
Ordering continuousOrdering() const { return continuous_ordering_; }
|
||||
Ordering discreteOrdering() const { return discrete_ordering_; }
|
||||
|
|
|
|||
Loading…
Reference in New Issue