split up the eliminate method to constituent parts
parent
addbe2a5a5
commit
ae0b3e3902
|
|
@ -47,4 +47,92 @@ bool HybridEliminationTree::equals(const This& other, double tol) const {
|
||||||
return Base::equals(other, tol);
|
return Base::equals(other, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||||
|
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||||
|
HybridEliminationTree::eliminateContinuous(Eliminate function) const {
|
||||||
|
if (continuous_ordering_.size() > 0) {
|
||||||
|
This continuous_etree(graph_, variable_index_, continuous_ordering_);
|
||||||
|
return continuous_etree.Base::eliminate(function);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
HybridBayesNet::shared_ptr bayesNet = boost::make_shared<HybridBayesNet>();
|
||||||
|
HybridGaussianFactorGraph::shared_ptr discreteGraph =
|
||||||
|
boost::make_shared<HybridGaussianFactorGraph>(graph_);
|
||||||
|
return std::make_pair(bayesNet, discreteGraph);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
boost::shared_ptr<HybridGaussianFactorGraph>
|
||||||
|
HybridEliminationTree::addProbPrimes(
|
||||||
|
const HybridBayesNet::shared_ptr& continuousBayesNet,
|
||||||
|
const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const {
|
||||||
|
if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) {
|
||||||
|
// Get the last continuous conditional
|
||||||
|
// which will have all the discrete keys
|
||||||
|
HybridConditional::shared_ptr last_conditional =
|
||||||
|
continuousBayesNet->at(continuousBayesNet->size() - 1);
|
||||||
|
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||||
|
|
||||||
|
// DecisionTree for P'(X|M, Z) for all mode sequences M
|
||||||
|
const AlgebraicDecisionTree<Key> probPrimeTree =
|
||||||
|
graph_.continuousProbPrimes(discrete_keys, continuousBayesNet);
|
||||||
|
|
||||||
|
// Add the model selection factor P(M|Z)
|
||||||
|
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||||
|
}
|
||||||
|
return discreteGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||||
|
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||||
|
HybridEliminationTree::eliminateDiscrete(
|
||||||
|
Eliminate function,
|
||||||
|
const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const {
|
||||||
|
HybridBayesNet::shared_ptr discreteBayesNet;
|
||||||
|
HybridGaussianFactorGraph::shared_ptr finalGraph;
|
||||||
|
if (discrete_ordering_.size() > 0) {
|
||||||
|
This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph),
|
||||||
|
discrete_ordering_);
|
||||||
|
|
||||||
|
std::tie(discreteBayesNet, finalGraph) =
|
||||||
|
discrete_etree.Base::eliminate(function);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
discreteBayesNet = boost::make_shared<HybridBayesNet>();
|
||||||
|
finalGraph = discreteGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_pair(discreteBayesNet, finalGraph);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||||
|
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||||
|
HybridEliminationTree::eliminate(Eliminate function) const {
|
||||||
|
// Perform continuous elimination
|
||||||
|
HybridBayesNet::shared_ptr bayesNet;
|
||||||
|
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||||
|
std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function);
|
||||||
|
|
||||||
|
// If we have eliminated continuous variables
|
||||||
|
// and have discrete variables to eliminate,
|
||||||
|
// then compute P(X | M, Z)
|
||||||
|
HybridGaussianFactorGraph::shared_ptr updatedDiscreteGraph =
|
||||||
|
addProbPrimes(bayesNet, discreteGraph);
|
||||||
|
|
||||||
|
// Perform discrete elimination
|
||||||
|
HybridBayesNet::shared_ptr discreteBayesNet;
|
||||||
|
HybridGaussianFactorGraph::shared_ptr finalGraph;
|
||||||
|
std::tie(discreteBayesNet, finalGraph) =
|
||||||
|
eliminateDiscrete(function, updatedDiscreteGraph);
|
||||||
|
|
||||||
|
// Add the discrete conditionals to the hybrid conditionals
|
||||||
|
bayesNet->add(*discreteBayesNet);
|
||||||
|
|
||||||
|
return std::make_pair(bayesNet, finalGraph);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -80,22 +80,12 @@ class GTSAM_EXPORT HybridEliminationTree
|
||||||
* and the original graph.
|
* and the original graph.
|
||||||
*
|
*
|
||||||
* @param function Elimination function for hybrid elimination.
|
* @param function Elimination function for hybrid elimination.
|
||||||
* @return std::pair<boost::shared_ptr<BayesNetType>,
|
* @return std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||||
* boost::shared_ptr<FactorGraphType> >
|
* boost::shared_ptr<HybridGaussianFactorGraph> >
|
||||||
*/
|
*/
|
||||||
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
|
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||||
eliminateContinuous(Eliminate function) const {
|
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||||
if (continuous_ordering_.size() > 0) {
|
eliminateContinuous(Eliminate function) const;
|
||||||
This continuous_etree(graph_, variable_index_, continuous_ordering_);
|
|
||||||
return continuous_etree.Base::eliminate(function);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
BayesNetType::shared_ptr bayesNet = boost::make_shared<BayesNetType>();
|
|
||||||
FactorGraphType::shared_ptr discreteGraph =
|
|
||||||
boost::make_shared<FactorGraphType>(graph_);
|
|
||||||
return std::make_pair(bayesNet, discreteGraph);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Helper method to eliminate the discrete variables after the
|
* @brief Helper method to eliminate the discrete variables after the
|
||||||
|
|
@ -104,75 +94,28 @@ class GTSAM_EXPORT HybridEliminationTree
|
||||||
* If there are no discrete variables, return an empty bayes net and the
|
* If there are no discrete variables, return an empty bayes net and the
|
||||||
* discreteGraph which is passed in.
|
* discreteGraph which is passed in.
|
||||||
*
|
*
|
||||||
* @param function Elimination function
|
* @param function Hybrid elimination function
|
||||||
* @param discreteGraph The factor graph with the factor ϕ(X | M, Z).
|
* @param discreteGraph The factor graph with the factor ϕ(X | M, Z).
|
||||||
* @return std::pair<boost::shared_ptr<BayesNetType>,
|
* @return std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||||
* boost::shared_ptr<FactorGraphType> >
|
* boost::shared_ptr<HybridGaussianFactorGraph> >
|
||||||
*/
|
*/
|
||||||
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
|
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||||
eliminateDiscrete(Eliminate function,
|
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||||
const FactorGraphType::shared_ptr& discreteGraph) const {
|
eliminateDiscrete(
|
||||||
BayesNetType::shared_ptr discreteBayesNet;
|
Eliminate function,
|
||||||
FactorGraphType::shared_ptr finalGraph;
|
const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const;
|
||||||
if (discrete_ordering_.size() > 0) {
|
|
||||||
This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph),
|
|
||||||
discrete_ordering_);
|
|
||||||
|
|
||||||
std::tie(discreteBayesNet, finalGraph) =
|
|
||||||
discrete_etree.Base::eliminate(function);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
discreteBayesNet = boost::make_shared<BayesNetType>();
|
|
||||||
finalGraph = discreteGraph;
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::make_pair(discreteBayesNet, finalGraph);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Override the EliminationTree eliminate method
|
* @brief Override the EliminationTree eliminate method
|
||||||
* so we can perform hybrid elimination correctly.
|
* so we can perform hybrid elimination correctly.
|
||||||
*
|
*
|
||||||
* @param function
|
* @param function Hybrid elimination function
|
||||||
* @return std::pair<boost::shared_ptr<BayesNetType>,
|
* @return std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||||
* boost::shared_ptr<FactorGraphType> >
|
* boost::shared_ptr<HybridGaussianFactorGraph> >
|
||||||
*/
|
*/
|
||||||
std::pair<boost::shared_ptr<BayesNetType>, boost::shared_ptr<FactorGraphType>>
|
std::pair<boost::shared_ptr<HybridBayesNet>,
|
||||||
eliminate(Eliminate function) const {
|
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||||
// Perform continuous elimination
|
eliminate(Eliminate function) const;
|
||||||
BayesNetType::shared_ptr bayesNet;
|
|
||||||
FactorGraphType::shared_ptr discreteGraph;
|
|
||||||
std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function);
|
|
||||||
|
|
||||||
// If we have eliminated continuous variables
|
|
||||||
// and have discrete variables to eliminate,
|
|
||||||
// then compute P(X | M, Z)
|
|
||||||
if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) {
|
|
||||||
// Get the last continuous conditional
|
|
||||||
// which will have all the discrete keys
|
|
||||||
HybridConditional::shared_ptr last_conditional =
|
|
||||||
bayesNet->at(bayesNet->size() - 1);
|
|
||||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
|
||||||
|
|
||||||
// DecisionTree for P'(X|M, Z) for all mode sequences M
|
|
||||||
const AlgebraicDecisionTree<Key> probPrimeTree =
|
|
||||||
graph_.continuousProbPrimes(discrete_keys, bayesNet);
|
|
||||||
|
|
||||||
// Add the model selection factor P(M|Z)
|
|
||||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform discrete elimination
|
|
||||||
BayesNetType::shared_ptr discreteBayesNet;
|
|
||||||
FactorGraphType::shared_ptr finalGraph;
|
|
||||||
std::tie(discreteBayesNet, finalGraph) =
|
|
||||||
eliminateDiscrete(function, discreteGraph);
|
|
||||||
|
|
||||||
// Add the discrete conditionals to the hybrid conditionals
|
|
||||||
bayesNet->add(*discreteBayesNet);
|
|
||||||
|
|
||||||
return std::make_pair(bayesNet, finalGraph);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ordering continuousOrdering() const { return continuous_ordering_; }
|
Ordering continuousOrdering() const { return continuous_ordering_; }
|
||||||
Ordering discreteOrdering() const { return discrete_ordering_; }
|
Ordering discreteOrdering() const { return discrete_ordering_; }
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue