override eliminate in HybridJunctionTree

release/4.3a0
Varun Agrawal 2022-12-04 14:55:17 +05:30
parent 15fffeb18a
commit addbe2a5a5
5 changed files with 209 additions and 123 deletions

View File

@ -550,98 +550,5 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(
return std::make_pair(continuous_ordering, discrete_ordering); return std::make_pair(continuous_ordering, discrete_ordering);
} }
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
HybridGaussianFactorGraph::eliminateHybridMultifrontal(
const boost::optional<Ordering> continuous,
const boost::optional<Ordering> discrete, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
const Ordering continuous_ordering =
continuous ? *continuous : Ordering(this->continuousKeys());
const Ordering discrete_ordering =
discrete ? *discrete : Ordering(this->discreteKeys());
// Eliminate continuous
HybridBayesTree::shared_ptr bayesTree;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesTree, discreteGraph) =
BaseEliminateable::eliminatePartialMultifrontal(continuous_ordering,
function, variableIndex);
// Get the last continuous conditional which will have all the discrete
const Key last_continuous_key = continuous_ordering.back();
HybridConditional::shared_ptr last_conditional =
(*bayesTree)[last_continuous_key]->conditional();
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
// If not discrete variables, return the eliminated bayes net.
if (discrete_keys.size() == 0) {
return bayesTree;
}
// DecisionTree for P'(X|M, Z) for all mode sequences M
const AlgebraicDecisionTree<Key> probPrimeTree =
this->continuousProbPrimes(discrete_keys, bayesTree);
// Add the model selection factor P(M|Z)
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
// Eliminate discrete variables to get the discrete bayes tree.
// This bayes tree will be updated with the
// continuous variables as the child nodes.
HybridBayesTree::shared_ptr updatedBayesTree =
discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete_ordering,
function);
// Get the clique with all the discrete keys.
// There should only be 1 clique.
const HybridBayesTree::sharedClique discrete_clique =
(*updatedBayesTree)[discrete_ordering.at(0)];
std::set<HybridBayesTreeClique::shared_ptr> clique_set;
for (auto node : bayesTree->nodes()) {
clique_set.insert(node.second);
}
// Set the root of the bayes tree as the discrete clique
for (auto clique : clique_set) {
if (clique->conditional()->parents() ==
discrete_clique->conditional()->frontals()) {
updatedBayesTree->addClique(clique, discrete_clique);
} else {
// Remove the clique from the children of the parents since it will get
// added again in addClique.
auto clique_it = std::find(clique->parent()->children.begin(),
clique->parent()->children.end(), clique);
clique->parent()->children.erase(clique_it);
updatedBayesTree->addClique(clique, clique->parent());
}
}
return updatedBayesTree;
}
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
HybridGaussianFactorGraph::eliminateMultifrontal(
OptionalOrderingType orderingType, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
return BaseEliminateable::eliminateMultifrontal(orderingType, function,
variableIndex);
}
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
HybridGaussianFactorGraph::eliminateMultifrontal(
const Ordering &ordering, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
// Segregate the continuous and the discrete keys
Ordering continuous_ordering, discrete_ordering;
std::tie(continuous_ordering, discrete_ordering) =
this->separateContinuousDiscreteOrdering(ordering);
return this->eliminateHybridMultifrontal(
continuous_ordering, discrete_ordering, function, variableIndex);
}
} // namespace gtsam } // namespace gtsam

View File

@ -302,31 +302,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering( std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering(
const Ordering& ordering) const; const Ordering& ordering) const;
/**
* @brief Custom elimination function which computes the correct
* continuous probabilities. Returns a bayes tree.
*
* @param continuous Optional ordering for all continuous variables.
* @param discrete Optional ordering for all discrete variables.
* @return boost::shared_ptr<BayesTreeType>
*/
boost::shared_ptr<BayesTreeType> eliminateHybridMultifrontal(
const boost::optional<Ordering> continuous = boost::none,
const boost::optional<Ordering> discrete = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
/// Multifrontal elimination overload for hybrid
boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
OptionalOrderingType orderingType = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
/// Multifrontal elimination overload for hybrid
boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
const Ordering& ordering,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
/** /**
* @brief Return a Colamd constrained ordering where the discrete keys are * @brief Return a Colamd constrained ordering where the discrete keys are

View File

@ -142,7 +142,8 @@ struct HybridConstructorTraversalData {
/* ************************************************************************* */ /* ************************************************************************* */
HybridJunctionTree::HybridJunctionTree( HybridJunctionTree::HybridJunctionTree(
const HybridEliminationTree& eliminationTree) { const HybridEliminationTree& eliminationTree)
: etree_(eliminationTree) {
gttic(JunctionTree_FromEliminationTree); gttic(JunctionTree_FromEliminationTree);
// Here we rely on the BayesNet having been produced by this elimination tree, // Here we rely on the BayesNet having been produced by this elimination tree,
// such that the conditionals are arranged in DFS post-order. We traverse the // such that the conditionals are arranged in DFS post-order. We traverse the
@ -169,4 +170,140 @@ HybridJunctionTree::HybridJunctionTree(
Base::remainingFactors_ = eliminationTree.remainingFactors(); Base::remainingFactors_ = eliminationTree.remainingFactors();
} }
/* ************************************************************************* */
std::pair<boost::shared_ptr<HybridBayesTree>,
boost::shared_ptr<HybridGaussianFactorGraph>>
HybridJunctionTree::eliminateContinuous(
const Eliminate& function, const HybridGaussianFactorGraph& graph,
const Ordering& continuous_ordering) const {
HybridBayesTree::shared_ptr continuousBayesTree;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
if (continuous_ordering.size() > 0) {
HybridEliminationTree continuous_etree(graph, etree_.variableIndex(),
continuous_ordering);
This continuous_junction_tree(continuous_etree);
std::tie(continuousBayesTree, discreteGraph) =
continuous_junction_tree.Base::eliminate(function);
} else {
continuousBayesTree = boost::make_shared<HybridBayesTree>();
discreteGraph = boost::make_shared<HybridGaussianFactorGraph>(graph);
}
return std::make_pair(continuousBayesTree, discreteGraph);
}
/* ************************************************************************* */
std::pair<boost::shared_ptr<HybridBayesTree>,
boost::shared_ptr<HybridGaussianFactorGraph>>
HybridJunctionTree::eliminateDiscrete(
const Eliminate& function,
const HybridBayesTree::shared_ptr& continuousBayesTree,
const HybridGaussianFactorGraph::shared_ptr& discreteGraph,
const Ordering& discrete_ordering) const {
HybridBayesTree::shared_ptr updatedBayesTree;
HybridGaussianFactorGraph::shared_ptr finalGraph;
if (discrete_ordering.size() > 0) {
HybridEliminationTree discrete_etree(
*discreteGraph, VariableIndex(*discreteGraph), discrete_ordering);
This discrete_junction_tree(discrete_etree);
std::tie(updatedBayesTree, finalGraph) =
discrete_junction_tree.Base::eliminate(function);
// Get the clique with all the discrete keys.
// There should only be 1 clique.
const HybridBayesTree::sharedClique discrete_clique =
(*updatedBayesTree)[discrete_ordering.at(0)];
std::set<HybridBayesTreeClique::shared_ptr> clique_set;
for (auto node : continuousBayesTree->nodes()) {
clique_set.insert(node.second);
}
// Set the root of the bayes tree as the discrete clique
for (auto clique : clique_set) {
if (clique->conditional()->parents() ==
discrete_clique->conditional()->frontals()) {
updatedBayesTree->addClique(clique, discrete_clique);
} else {
if (clique->parent()) {
// Remove the clique from the children of the parents since it will
// get added again in addClique.
auto clique_it = std::find(clique->parent()->children.begin(),
clique->parent()->children.end(), clique);
clique->parent()->children.erase(clique_it);
updatedBayesTree->addClique(clique, clique->parent());
} else {
updatedBayesTree->addClique(clique);
}
}
}
} else {
updatedBayesTree = continuousBayesTree;
finalGraph = discreteGraph;
}
return std::make_pair(updatedBayesTree, finalGraph);
}
/* ************************************************************************* */
boost::shared_ptr<HybridGaussianFactorGraph> HybridJunctionTree::addProbPrimes(
const HybridGaussianFactorGraph& graph,
const HybridBayesTree::shared_ptr& continuousBayesTree,
const HybridGaussianFactorGraph::shared_ptr& discreteGraph,
const Ordering& continuous_ordering,
const Ordering& discrete_ordering) const {
// If we have eliminated continuous variables
// and have discrete variables to eliminate,
// then compute P(X | M, Z)
if (continuous_ordering.size() > 0 && discrete_ordering.size() > 0) {
// Collect all the discrete keys
DiscreteKeys discrete_keys;
for (auto node : continuousBayesTree->nodes()) {
auto node_dkeys = node.second->conditional()->discreteKeys();
discrete_keys.insert(discrete_keys.end(), node_dkeys.begin(),
node_dkeys.end());
}
// Remove duplicates and convert back to DiscreteKeys
std::set<DiscreteKey> dkeys_set(discrete_keys.begin(), discrete_keys.end());
discrete_keys = DiscreteKeys(dkeys_set.begin(), dkeys_set.end());
// DecisionTree for P'(X|M, Z) for all mode sequences M
const AlgebraicDecisionTree<Key> probPrimeTree =
graph.continuousProbPrimes(discrete_keys, continuousBayesTree);
// Add the model selection factor P(M|Z)
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
}
return discreteGraph;
}
/* ************************************************************************* */
std::pair<HybridBayesTree::shared_ptr, HybridGaussianFactorGraph::shared_ptr>
HybridJunctionTree::eliminate(const Eliminate& function) const {
Ordering continuous_ordering = etree_.continuousOrdering();
Ordering discrete_ordering = etree_.discreteOrdering();
FactorGraphType graph = etree_.graph();
// Eliminate continuous
BayesTreeType::shared_ptr continuousBayesTree;
FactorGraphType::shared_ptr discreteGraph;
std::tie(continuousBayesTree, discreteGraph) =
this->eliminateContinuous(function, graph, continuous_ordering);
FactorGraphType::shared_ptr updatedDiscreteGraph =
this->addProbPrimes(graph, continuousBayesTree, discreteGraph,
continuous_ordering, discrete_ordering);
// Eliminate discrete variables to get the discrete bayes tree.
return this->eliminateDiscrete(function, continuousBayesTree,
updatedDiscreteGraph, discrete_ordering);
}
} // namespace gtsam } // namespace gtsam

View File

@ -51,10 +51,15 @@ class HybridEliminationTree;
*/ */
class GTSAM_EXPORT HybridJunctionTree class GTSAM_EXPORT HybridJunctionTree
: public JunctionTree<HybridBayesTree, HybridGaussianFactorGraph> { : public JunctionTree<HybridBayesTree, HybridGaussianFactorGraph> {
/// Record the elimination tree for use in hybrid elimination.
HybridEliminationTree etree_;
/// Store the provided variable index.
VariableIndex variable_index_;
public: public:
typedef JunctionTree<HybridBayesTree, HybridGaussianFactorGraph> typedef JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>
Base; ///< Base class Base; ///< Base class
typedef HybridJunctionTree This; ///< This class typedef HybridJunctionTree This; ///< This class
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
/** /**
@ -67,6 +72,66 @@ class GTSAM_EXPORT HybridJunctionTree
* @return The elimination tree * @return The elimination tree
*/ */
HybridJunctionTree(const HybridEliminationTree& eliminationTree); HybridJunctionTree(const HybridEliminationTree& eliminationTree);
/**
* @brief
*
* @param function
* @param graph
* @param continuous_ordering
* @return std::pair<boost::shared_ptr<HybridBayesTree>,
* boost::shared_ptr<HybridGaussianFactorGraph>>
*/
std::pair<boost::shared_ptr<HybridBayesTree>,
boost::shared_ptr<HybridGaussianFactorGraph>>
eliminateContinuous(const Eliminate& function,
const HybridGaussianFactorGraph& graph,
const Ordering& continuous_ordering) const;
/**
* @brief
*
* @param function
* @param continuousBayesTree
* @param discreteGraph
* @param discrete_ordering
* @return std::pair<boost::shared_ptr<HybridBayesTree>,
* boost::shared_ptr<HybridGaussianFactorGraph>>
*/
std::pair<boost::shared_ptr<HybridBayesTree>,
boost::shared_ptr<HybridGaussianFactorGraph>>
eliminateDiscrete(const Eliminate& function,
const HybridBayesTree::shared_ptr& continuousBayesTree,
const HybridGaussianFactorGraph::shared_ptr& discreteGraph,
const Ordering& discrete_ordering) const;
/**
* @brief
*
* @param graph
* @param continuousBayesTree
* @param discreteGraph
* @param continuous_ordering
* @param discrete_ordering
* @return boost::shared_ptr<HybridGaussianFactorGraph>
*/
boost::shared_ptr<HybridGaussianFactorGraph> addProbPrimes(
const HybridGaussianFactorGraph& graph,
const HybridBayesTree::shared_ptr& continuousBayesTree,
const HybridGaussianFactorGraph::shared_ptr& discreteGraph,
const Ordering& continuous_ordering,
const Ordering& discrete_ordering) const;
/**
* @brief
*
* @param function
* @return std::pair<boost::shared_ptr<HybridBayesTree>,
* boost::shared_ptr<HybridGaussianFactorGraph>>
*/
std::pair<boost::shared_ptr<HybridBayesTree>,
boost::shared_ptr<HybridGaussianFactorGraph>>
eliminate(const Eliminate& function) const;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -182,8 +182,9 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())})); boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())}));
hfg.add(DecisionTreeFactor(m1, {2, 8})); hfg.add(DecisionTreeFactor(m1, {2, 8}));
//TODO(Varun) Adding extra discrete variable not connected to continuous variable throws segfault // TODO(Varun) Adding extra discrete variable not connected to continuous
// hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); // variable throws segfault
// hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
HybridBayesTree::shared_ptr result = HybridBayesTree::shared_ptr result =
hfg.eliminateMultifrontal(hfg.getHybridOrdering()); hfg.eliminateMultifrontal(hfg.getHybridOrdering());
@ -276,7 +277,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal(ordering_full); std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal(ordering_full);
// 9 cliques in the bayes tree and 0 remaining variables to eliminate. // 9 cliques in the bayes tree and 0 remaining variables to eliminate.
EXPECT_LONGS_EQUAL(9, hbt->size()); EXPECT_LONGS_EQUAL(7, hbt->size());
EXPECT_LONGS_EQUAL(0, remaining->size()); EXPECT_LONGS_EQUAL(0, remaining->size());
/* /*