override eliminate in HybridJunctionTree
parent
15fffeb18a
commit
addbe2a5a5
|
@ -550,98 +550,5 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(
|
|||
return std::make_pair(continuous_ordering, discrete_ordering);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
|
||||
HybridGaussianFactorGraph::eliminateHybridMultifrontal(
|
||||
const boost::optional<Ordering> continuous,
|
||||
const boost::optional<Ordering> discrete, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
const Ordering continuous_ordering =
|
||||
continuous ? *continuous : Ordering(this->continuousKeys());
|
||||
const Ordering discrete_ordering =
|
||||
discrete ? *discrete : Ordering(this->discreteKeys());
|
||||
|
||||
// Eliminate continuous
|
||||
HybridBayesTree::shared_ptr bayesTree;
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||
std::tie(bayesTree, discreteGraph) =
|
||||
BaseEliminateable::eliminatePartialMultifrontal(continuous_ordering,
|
||||
function, variableIndex);
|
||||
|
||||
// Get the last continuous conditional which will have all the discrete
|
||||
const Key last_continuous_key = continuous_ordering.back();
|
||||
HybridConditional::shared_ptr last_conditional =
|
||||
(*bayesTree)[last_continuous_key]->conditional();
|
||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||
|
||||
// If not discrete variables, return the eliminated bayes net.
|
||||
if (discrete_keys.size() == 0) {
|
||||
return bayesTree;
|
||||
}
|
||||
|
||||
// DecisionTree for P'(X|M, Z) for all mode sequences M
|
||||
const AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
this->continuousProbPrimes(discrete_keys, bayesTree);
|
||||
|
||||
// Add the model selection factor P(M|Z)
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
|
||||
// Eliminate discrete variables to get the discrete bayes tree.
|
||||
// This bayes tree will be updated with the
|
||||
// continuous variables as the child nodes.
|
||||
HybridBayesTree::shared_ptr updatedBayesTree =
|
||||
discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete_ordering,
|
||||
function);
|
||||
|
||||
// Get the clique with all the discrete keys.
|
||||
// There should only be 1 clique.
|
||||
const HybridBayesTree::sharedClique discrete_clique =
|
||||
(*updatedBayesTree)[discrete_ordering.at(0)];
|
||||
|
||||
std::set<HybridBayesTreeClique::shared_ptr> clique_set;
|
||||
for (auto node : bayesTree->nodes()) {
|
||||
clique_set.insert(node.second);
|
||||
}
|
||||
|
||||
// Set the root of the bayes tree as the discrete clique
|
||||
for (auto clique : clique_set) {
|
||||
if (clique->conditional()->parents() ==
|
||||
discrete_clique->conditional()->frontals()) {
|
||||
updatedBayesTree->addClique(clique, discrete_clique);
|
||||
|
||||
} else {
|
||||
// Remove the clique from the children of the parents since it will get
|
||||
// added again in addClique.
|
||||
auto clique_it = std::find(clique->parent()->children.begin(),
|
||||
clique->parent()->children.end(), clique);
|
||||
clique->parent()->children.erase(clique_it);
|
||||
updatedBayesTree->addClique(clique, clique->parent());
|
||||
}
|
||||
}
|
||||
return updatedBayesTree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
|
||||
HybridGaussianFactorGraph::eliminateMultifrontal(
|
||||
OptionalOrderingType orderingType, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
return BaseEliminateable::eliminateMultifrontal(orderingType, function,
|
||||
variableIndex);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
|
||||
HybridGaussianFactorGraph::eliminateMultifrontal(
|
||||
const Ordering &ordering, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
// Segregate the continuous and the discrete keys
|
||||
Ordering continuous_ordering, discrete_ordering;
|
||||
std::tie(continuous_ordering, discrete_ordering) =
|
||||
this->separateContinuousDiscreteOrdering(ordering);
|
||||
|
||||
return this->eliminateHybridMultifrontal(
|
||||
continuous_ordering, discrete_ordering, function, variableIndex);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -302,31 +302,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering(
|
||||
const Ordering& ordering) const;
|
||||
|
||||
/**
|
||||
* @brief Custom elimination function which computes the correct
|
||||
* continuous probabilities. Returns a bayes tree.
|
||||
*
|
||||
* @param continuous Optional ordering for all continuous variables.
|
||||
* @param discrete Optional ordering for all discrete variables.
|
||||
* @return boost::shared_ptr<BayesTreeType>
|
||||
*/
|
||||
boost::shared_ptr<BayesTreeType> eliminateHybridMultifrontal(
|
||||
const boost::optional<Ordering> continuous = boost::none,
|
||||
const boost::optional<Ordering> discrete = boost::none,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/// Multifrontal elimination overload for hybrid
|
||||
boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
|
||||
OptionalOrderingType orderingType = boost::none,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/// Multifrontal elimination overload for hybrid
|
||||
boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
|
||||
const Ordering& ordering,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
|
|
|
@ -142,7 +142,8 @@ struct HybridConstructorTraversalData {
|
|||
|
||||
/* ************************************************************************* */
|
||||
HybridJunctionTree::HybridJunctionTree(
|
||||
const HybridEliminationTree& eliminationTree) {
|
||||
const HybridEliminationTree& eliminationTree)
|
||||
: etree_(eliminationTree) {
|
||||
gttic(JunctionTree_FromEliminationTree);
|
||||
// Here we rely on the BayesNet having been produced by this elimination tree,
|
||||
// such that the conditionals are arranged in DFS post-order. We traverse the
|
||||
|
@ -169,4 +170,140 @@ HybridJunctionTree::HybridJunctionTree(
|
|||
Base::remainingFactors_ = eliminationTree.remainingFactors();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::pair<boost::shared_ptr<HybridBayesTree>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
HybridJunctionTree::eliminateContinuous(
|
||||
const Eliminate& function, const HybridGaussianFactorGraph& graph,
|
||||
const Ordering& continuous_ordering) const {
|
||||
HybridBayesTree::shared_ptr continuousBayesTree;
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||
|
||||
if (continuous_ordering.size() > 0) {
|
||||
HybridEliminationTree continuous_etree(graph, etree_.variableIndex(),
|
||||
continuous_ordering);
|
||||
|
||||
This continuous_junction_tree(continuous_etree);
|
||||
std::tie(continuousBayesTree, discreteGraph) =
|
||||
continuous_junction_tree.Base::eliminate(function);
|
||||
|
||||
} else {
|
||||
continuousBayesTree = boost::make_shared<HybridBayesTree>();
|
||||
discreteGraph = boost::make_shared<HybridGaussianFactorGraph>(graph);
|
||||
}
|
||||
|
||||
return std::make_pair(continuousBayesTree, discreteGraph);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
std::pair<boost::shared_ptr<HybridBayesTree>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
HybridJunctionTree::eliminateDiscrete(
|
||||
const Eliminate& function,
|
||||
const HybridBayesTree::shared_ptr& continuousBayesTree,
|
||||
const HybridGaussianFactorGraph::shared_ptr& discreteGraph,
|
||||
const Ordering& discrete_ordering) const {
|
||||
HybridBayesTree::shared_ptr updatedBayesTree;
|
||||
HybridGaussianFactorGraph::shared_ptr finalGraph;
|
||||
if (discrete_ordering.size() > 0) {
|
||||
HybridEliminationTree discrete_etree(
|
||||
*discreteGraph, VariableIndex(*discreteGraph), discrete_ordering);
|
||||
|
||||
This discrete_junction_tree(discrete_etree);
|
||||
|
||||
std::tie(updatedBayesTree, finalGraph) =
|
||||
discrete_junction_tree.Base::eliminate(function);
|
||||
|
||||
// Get the clique with all the discrete keys.
|
||||
// There should only be 1 clique.
|
||||
const HybridBayesTree::sharedClique discrete_clique =
|
||||
(*updatedBayesTree)[discrete_ordering.at(0)];
|
||||
|
||||
std::set<HybridBayesTreeClique::shared_ptr> clique_set;
|
||||
for (auto node : continuousBayesTree->nodes()) {
|
||||
clique_set.insert(node.second);
|
||||
}
|
||||
|
||||
// Set the root of the bayes tree as the discrete clique
|
||||
for (auto clique : clique_set) {
|
||||
if (clique->conditional()->parents() ==
|
||||
discrete_clique->conditional()->frontals()) {
|
||||
updatedBayesTree->addClique(clique, discrete_clique);
|
||||
|
||||
} else {
|
||||
if (clique->parent()) {
|
||||
// Remove the clique from the children of the parents since it will
|
||||
// get added again in addClique.
|
||||
auto clique_it = std::find(clique->parent()->children.begin(),
|
||||
clique->parent()->children.end(), clique);
|
||||
clique->parent()->children.erase(clique_it);
|
||||
updatedBayesTree->addClique(clique, clique->parent());
|
||||
} else {
|
||||
updatedBayesTree->addClique(clique);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
updatedBayesTree = continuousBayesTree;
|
||||
finalGraph = discreteGraph;
|
||||
}
|
||||
|
||||
return std::make_pair(updatedBayesTree, finalGraph);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph> HybridJunctionTree::addProbPrimes(
|
||||
const HybridGaussianFactorGraph& graph,
|
||||
const HybridBayesTree::shared_ptr& continuousBayesTree,
|
||||
const HybridGaussianFactorGraph::shared_ptr& discreteGraph,
|
||||
const Ordering& continuous_ordering,
|
||||
const Ordering& discrete_ordering) const {
|
||||
// If we have eliminated continuous variables
|
||||
// and have discrete variables to eliminate,
|
||||
// then compute P(X | M, Z)
|
||||
if (continuous_ordering.size() > 0 && discrete_ordering.size() > 0) {
|
||||
// Collect all the discrete keys
|
||||
DiscreteKeys discrete_keys;
|
||||
for (auto node : continuousBayesTree->nodes()) {
|
||||
auto node_dkeys = node.second->conditional()->discreteKeys();
|
||||
discrete_keys.insert(discrete_keys.end(), node_dkeys.begin(),
|
||||
node_dkeys.end());
|
||||
}
|
||||
// Remove duplicates and convert back to DiscreteKeys
|
||||
std::set<DiscreteKey> dkeys_set(discrete_keys.begin(), discrete_keys.end());
|
||||
discrete_keys = DiscreteKeys(dkeys_set.begin(), dkeys_set.end());
|
||||
|
||||
// DecisionTree for P'(X|M, Z) for all mode sequences M
|
||||
const AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
graph.continuousProbPrimes(discrete_keys, continuousBayesTree);
|
||||
|
||||
// Add the model selection factor P(M|Z)
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
}
|
||||
|
||||
return discreteGraph;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::pair<HybridBayesTree::shared_ptr, HybridGaussianFactorGraph::shared_ptr>
|
||||
HybridJunctionTree::eliminate(const Eliminate& function) const {
|
||||
Ordering continuous_ordering = etree_.continuousOrdering();
|
||||
Ordering discrete_ordering = etree_.discreteOrdering();
|
||||
|
||||
FactorGraphType graph = etree_.graph();
|
||||
|
||||
// Eliminate continuous
|
||||
BayesTreeType::shared_ptr continuousBayesTree;
|
||||
FactorGraphType::shared_ptr discreteGraph;
|
||||
std::tie(continuousBayesTree, discreteGraph) =
|
||||
this->eliminateContinuous(function, graph, continuous_ordering);
|
||||
|
||||
FactorGraphType::shared_ptr updatedDiscreteGraph =
|
||||
this->addProbPrimes(graph, continuousBayesTree, discreteGraph,
|
||||
continuous_ordering, discrete_ordering);
|
||||
|
||||
// Eliminate discrete variables to get the discrete bayes tree.
|
||||
return this->eliminateDiscrete(function, continuousBayesTree,
|
||||
updatedDiscreteGraph, discrete_ordering);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -51,6 +51,11 @@ class HybridEliminationTree;
|
|||
*/
|
||||
class GTSAM_EXPORT HybridJunctionTree
|
||||
: public JunctionTree<HybridBayesTree, HybridGaussianFactorGraph> {
|
||||
/// Record the elimination tree for use in hybrid elimination.
|
||||
HybridEliminationTree etree_;
|
||||
/// Store the provided variable index.
|
||||
VariableIndex variable_index_;
|
||||
|
||||
public:
|
||||
typedef JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>
|
||||
Base; ///< Base class
|
||||
|
@ -67,6 +72,66 @@ class GTSAM_EXPORT HybridJunctionTree
|
|||
* @return The elimination tree
|
||||
*/
|
||||
HybridJunctionTree(const HybridEliminationTree& eliminationTree);
|
||||
|
||||
/**
|
||||
* @brief
|
||||
*
|
||||
* @param function
|
||||
* @param graph
|
||||
* @param continuous_ordering
|
||||
* @return std::pair<boost::shared_ptr<HybridBayesTree>,
|
||||
* boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
*/
|
||||
std::pair<boost::shared_ptr<HybridBayesTree>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
eliminateContinuous(const Eliminate& function,
|
||||
const HybridGaussianFactorGraph& graph,
|
||||
const Ordering& continuous_ordering) const;
|
||||
|
||||
/**
|
||||
* @brief
|
||||
*
|
||||
* @param function
|
||||
* @param continuousBayesTree
|
||||
* @param discreteGraph
|
||||
* @param discrete_ordering
|
||||
* @return std::pair<boost::shared_ptr<HybridBayesTree>,
|
||||
* boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
*/
|
||||
std::pair<boost::shared_ptr<HybridBayesTree>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
eliminateDiscrete(const Eliminate& function,
|
||||
const HybridBayesTree::shared_ptr& continuousBayesTree,
|
||||
const HybridGaussianFactorGraph::shared_ptr& discreteGraph,
|
||||
const Ordering& discrete_ordering) const;
|
||||
|
||||
/**
|
||||
* @brief
|
||||
*
|
||||
* @param graph
|
||||
* @param continuousBayesTree
|
||||
* @param discreteGraph
|
||||
* @param continuous_ordering
|
||||
* @param discrete_ordering
|
||||
* @return boost::shared_ptr<HybridGaussianFactorGraph>
|
||||
*/
|
||||
boost::shared_ptr<HybridGaussianFactorGraph> addProbPrimes(
|
||||
const HybridGaussianFactorGraph& graph,
|
||||
const HybridBayesTree::shared_ptr& continuousBayesTree,
|
||||
const HybridGaussianFactorGraph::shared_ptr& discreteGraph,
|
||||
const Ordering& continuous_ordering,
|
||||
const Ordering& discrete_ordering) const;
|
||||
|
||||
/**
|
||||
* @brief
|
||||
*
|
||||
* @param function
|
||||
* @return std::pair<boost::shared_ptr<HybridBayesTree>,
|
||||
* boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
*/
|
||||
std::pair<boost::shared_ptr<HybridBayesTree>,
|
||||
boost::shared_ptr<HybridGaussianFactorGraph>>
|
||||
eliminate(const Eliminate& function) const;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -182,7 +182,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
|
|||
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())}));
|
||||
|
||||
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
||||
//TODO(Varun) Adding extra discrete variable not connected to continuous variable throws segfault
|
||||
// TODO(Varun) Adding extra discrete variable not connected to continuous
|
||||
// variable throws segfault
|
||||
// hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
||||
|
||||
HybridBayesTree::shared_ptr result =
|
||||
|
@ -276,7 +277,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
|
|||
std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal(ordering_full);
|
||||
|
||||
// 9 cliques in the bayes tree and 0 remaining variables to eliminate.
|
||||
EXPECT_LONGS_EQUAL(9, hbt->size());
|
||||
EXPECT_LONGS_EQUAL(7, hbt->size());
|
||||
EXPECT_LONGS_EQUAL(0, remaining->size());
|
||||
|
||||
/*
|
||||
|
|
Loading…
Reference in New Issue