split up the eliminate method to constituent parts

release/4.3a0
Varun Agrawal 2022-12-04 18:21:22 +05:30
parent addbe2a5a5
commit ae0b3e3902
2 changed files with 107 additions and 76 deletions

View File

@ -47,4 +47,92 @@ bool HybridEliminationTree::equals(const This& other, double tol) const {
return Base::equals(other, tol); return Base::equals(other, tol);
} }
/* ************************************************************************* */
std::pair<boost::shared_ptr<HybridBayesNet>,
boost::shared_ptr<HybridGaussianFactorGraph>>
HybridEliminationTree::eliminateContinuous(Eliminate function) const {
if (continuous_ordering_.size() > 0) {
This continuous_etree(graph_, variable_index_, continuous_ordering_);
return continuous_etree.Base::eliminate(function);
} else {
HybridBayesNet::shared_ptr bayesNet = boost::make_shared<HybridBayesNet>();
HybridGaussianFactorGraph::shared_ptr discreteGraph =
boost::make_shared<HybridGaussianFactorGraph>(graph_);
return std::make_pair(bayesNet, discreteGraph);
}
}
/* ************************************************************************* */
boost::shared_ptr<HybridGaussianFactorGraph>
HybridEliminationTree::addProbPrimes(
const HybridBayesNet::shared_ptr& continuousBayesNet,
const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const {
if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) {
// Get the last continuous conditional
// which will have all the discrete keys
HybridConditional::shared_ptr last_conditional =
continuousBayesNet->at(continuousBayesNet->size() - 1);
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
// DecisionTree for P'(X|M, Z) for all mode sequences M
const AlgebraicDecisionTree<Key> probPrimeTree =
graph_.continuousProbPrimes(discrete_keys, continuousBayesNet);
// Add the model selection factor P(M|Z)
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
}
return discreteGraph;
}
/* ************************************************************************* */
std::pair<boost::shared_ptr<HybridBayesNet>,
boost::shared_ptr<HybridGaussianFactorGraph>>
HybridEliminationTree::eliminateDiscrete(
Eliminate function,
const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const {
HybridBayesNet::shared_ptr discreteBayesNet;
HybridGaussianFactorGraph::shared_ptr finalGraph;
if (discrete_ordering_.size() > 0) {
This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph),
discrete_ordering_);
std::tie(discreteBayesNet, finalGraph) =
discrete_etree.Base::eliminate(function);
} else {
discreteBayesNet = boost::make_shared<HybridBayesNet>();
finalGraph = discreteGraph;
}
return std::make_pair(discreteBayesNet, finalGraph);
}
/* ************************************************************************* */
std::pair<boost::shared_ptr<HybridBayesNet>,
boost::shared_ptr<HybridGaussianFactorGraph>>
HybridEliminationTree::eliminate(Eliminate function) const {
// Perform continuous elimination
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function);
// If we have eliminated continuous variables
// and have discrete variables to eliminate,
// then compute P(X | M, Z)
HybridGaussianFactorGraph::shared_ptr updatedDiscreteGraph =
addProbPrimes(bayesNet, discreteGraph);
// Perform discrete elimination
HybridBayesNet::shared_ptr discreteBayesNet;
HybridGaussianFactorGraph::shared_ptr finalGraph;
std::tie(discreteBayesNet, finalGraph) =
eliminateDiscrete(function, updatedDiscreteGraph);
// Add the discrete conditionals to the hybrid conditionals
bayesNet->add(*discreteBayesNet);
return std::make_pair(bayesNet, finalGraph);
}
} // namespace gtsam } // namespace gtsam

View File

@ -80,22 +80,12 @@ class GTSAM_EXPORT HybridEliminationTree
* and the original graph. * and the original graph.
* *
* @param function Elimination function for hybrid elimination. * @param function Elimination function for hybrid elimination.
* @return std::pair<boost::shared_ptr<BayesNetType>, * @return std::pair<boost::shared_ptr<HybridBayesNet>,
* boost::shared_ptr<FactorGraphType> > * boost::shared_ptr<HybridGaussianFactorGraph> >
*/ */
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>> std::pair<boost::shared_ptr<HybridBayesNet>,
eliminateContinuous(Eliminate function) const { boost::shared_ptr<HybridGaussianFactorGraph>>
if (continuous_ordering_.size() > 0) { eliminateContinuous(Eliminate function) const;
This continuous_etree(graph_, variable_index_, continuous_ordering_);
return continuous_etree.Base::eliminate(function);
} else {
BayesNetType::shared_ptr bayesNet = boost::make_shared<BayesNetType>();
FactorGraphType::shared_ptr discreteGraph =
boost::make_shared<FactorGraphType>(graph_);
return std::make_pair(bayesNet, discreteGraph);
}
}
/** /**
* @brief Helper method to eliminate the discrete variables after the * @brief Helper method to eliminate the discrete variables after the
@ -104,75 +94,28 @@ class GTSAM_EXPORT HybridEliminationTree
* If there are no discrete variables, return an empty bayes net and the * If there are no discrete variables, return an empty bayes net and the
* discreteGraph which is passed in. * discreteGraph which is passed in.
* *
* @param function Elimination function * @param function Hybrid elimination function
* @param discreteGraph The factor graph with the factor ϕ(X | M, Z). * @param discreteGraph The factor graph with the factor ϕ(X | M, Z).
* @return std::pair<boost::shared_ptr<BayesNetType>, * @return std::pair<boost::shared_ptr<HybridBayesNet>,
* boost::shared_ptr<FactorGraphType> > * boost::shared_ptr<HybridGaussianFactorGraph> >
*/ */
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>> std::pair<boost::shared_ptr<HybridBayesNet>,
eliminateDiscrete(Eliminate function, boost::shared_ptr<HybridGaussianFactorGraph>>
const FactorGraphType::shared_ptr& discreteGraph) const { eliminateDiscrete(
BayesNetType::shared_ptr discreteBayesNet; Eliminate function,
FactorGraphType::shared_ptr finalGraph; const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const;
if (discrete_ordering_.size() > 0) {
This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph),
discrete_ordering_);
std::tie(discreteBayesNet, finalGraph) =
discrete_etree.Base::eliminate(function);
} else {
discreteBayesNet = boost::make_shared<BayesNetType>();
finalGraph = discreteGraph;
}
return std::make_pair(discreteBayesNet, finalGraph);
}
/** /**
* @brief Override the EliminationTree eliminate method * @brief Override the EliminationTree eliminate method
* so we can perform hybrid elimination correctly. * so we can perform hybrid elimination correctly.
* *
* @param function * @param function Hybrid elimination function
* @return std::pair<boost::shared_ptr<BayesNetType>, * @return std::pair<boost::shared_ptr<HybridBayesNet>,
* boost::shared_ptr<FactorGraphType> > * boost::shared_ptr<HybridGaussianFactorGraph> >
*/ */
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>> std::pair<boost::shared_ptr<HybridBayesNet>,
eliminate(Eliminate function) const { boost::shared_ptr<HybridGaussianFactorGraph>>
// Perform continuous elimination eliminate(Eliminate function) const;
BayesNetType::shared_ptr bayesNet;
FactorGraphType::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function);
// If we have eliminated continuous variables
// and have discrete variables to eliminate,
// then compute P(X | M, Z)
if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) {
// Get the last continuous conditional
// which will have all the discrete keys
HybridConditional::shared_ptr last_conditional =
bayesNet->at(bayesNet->size() - 1);
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
// DecisionTree for P'(X|M, Z) for all mode sequences M
const AlgebraicDecisionTree<Key> probPrimeTree =
graph_.continuousProbPrimes(discrete_keys, bayesNet);
// Add the model selection factor P(M|Z)
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
}
// Perform discrete elimination
BayesNetType::shared_ptr discreteBayesNet;
FactorGraphType::shared_ptr finalGraph;
std::tie(discreteBayesNet, finalGraph) =
eliminateDiscrete(function, discreteGraph);
// Add the discrete conditionals to the hybrid conditionals
bayesNet->add(*discreteBayesNet);
return std::make_pair(bayesNet, finalGraph);
}
Ordering continuousOrdering() const { return continuous_ordering_; } Ordering continuousOrdering() const { return continuous_ordering_; }
Ordering discreteOrdering() const { return discrete_ordering_; } Ordering discreteOrdering() const { return discrete_ordering_; }