make continuousProbPrimes and continuousDeltas as templates

release/4.3a0
Varun Agrawal 2022-11-12 23:07:34 -05:00
parent 6e6bbfff4c
commit 5e2cdfdd3b
2 changed files with 69 additions and 124 deletions

View File

@ -528,118 +528,6 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
return prob_tree; return prob_tree;
} }
/* ************************************************************************ */
DecisionTree<Key, VectorValues::shared_ptr>
HybridGaussianFactorGraph::continuousDelta(
const DiscreteKeys &discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet,
const std::vector<DiscreteValues> &assignments) const {
// Create a decision tree of all the different VectorValues
std::vector<VectorValues::shared_ptr> vector_values;
for (const DiscreteValues &assignment : assignments) {
VectorValues values = continuousBayesNet->optimize(assignment);
vector_values.push_back(boost::make_shared<VectorValues>(values));
}
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values);
return delta_tree;
}
/* ************************************************************************ */
DecisionTree<Key, VectorValues::shared_ptr>
HybridGaussianFactorGraph::continuousDelta(
const DiscreteKeys &discrete_keys,
const boost::shared_ptr<BayesTreeType> &continuousBayesTree,
const std::vector<DiscreteValues> &assignments) const {
// Create a decision tree of all the different VectorValues
std::vector<VectorValues::shared_ptr> vector_values;
for (const DiscreteValues &assignment : assignments) {
VectorValues values = continuousBayesTree->optimize(assignment);
vector_values.push_back(boost::make_shared<VectorValues>(values));
}
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values);
return delta_tree;
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
const DiscreteKeys &orig_discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet) const {
// Generate all possible assignments.
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(orig_discrete_keys);
// Save a copy of the original discrete key ordering
DiscreteKeys discrete_keys(orig_discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());
// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
this->continuousDelta(discrete_keys, continuousBayesNet, assignments);
// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes;
for (const DiscreteValues &assignment : assignments) {
VectorValues delta = *delta_tree(assignment);
// If VectorValues is empty, it means this is a pruned branch.
// Set thr probPrime to 0.0.
if (delta.size() == 0) {
probPrimes.push_back(0.0);
continue;
}
// Compute the error given the delta and the assignment.
double error = this->error(delta, assignment);
probPrimes.push_back(exp(-error));
}
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
return probPrimeTree;
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
const DiscreteKeys &orig_discrete_keys,
const boost::shared_ptr<BayesTreeType> &continuousBayesTree) const {
// Generate all possible assignments.
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(orig_discrete_keys);
// Save a copy of the original discrete key ordering
DiscreteKeys discrete_keys(orig_discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());
// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
this->continuousDelta(discrete_keys, continuousBayesTree, assignments);
// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes;
for (const DiscreteValues &assignment : assignments) {
VectorValues delta = *delta_tree(assignment);
// If VectorValues is empty, it means this is a pruned branch.
// Set thr probPrime to 0.0.
if (delta.size() == 0) {
probPrimes.push_back(0.0);
continue;
}
// Compute the error given the delta and the assignment.
double error = this->error(delta, assignment);
probPrimes.push_back(exp(-error));
}
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
return probPrimeTree;
}
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<Ordering, Ordering> std::pair<Ordering, Ordering>
HybridGaussianFactorGraph::separateContinuousDiscreteOrdering( HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(

View File

@ -220,44 +220,89 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Compute the VectorValues solution for the continuous variables for * @brief Compute the VectorValues solution for the continuous variables for
* each mode. * each mode.
* *
* @tparam BAYES Template on the type of Bayes graph, either a bayes net or a
* bayes tree.
* @param discrete_keys The discrete keys which form all the modes. * @param discrete_keys The discrete keys which form all the modes.
* @param continuousBayesNet The Bayes Net representing the continuous * @param continuousBayesNet The Bayes Net/Tree representing the continuous
* eliminated variables. * eliminated variables.
* @param assignments List of all discrete assignments to create the final * @param assignments List of all discrete assignments to create the final
* decision tree. * decision tree.
* @return DecisionTree<Key, VectorValues::shared_ptr> * @return DecisionTree<Key, VectorValues::shared_ptr>
*/ */
template <typename BAYES>
DecisionTree<Key, VectorValues::shared_ptr> continuousDelta( DecisionTree<Key, VectorValues::shared_ptr> continuousDelta(
const DiscreteKeys& discrete_keys, const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesNetType>& continuousBayesNet, const boost::shared_ptr<BAYES>& continuousBayesNet,
const std::vector<DiscreteValues>& assignments) const; const std::vector<DiscreteValues>& assignments) const {
DecisionTree<Key, VectorValues::shared_ptr> continuousDelta( // Create a decision tree of all the different VectorValues
const DiscreteKeys& discrete_keys, std::vector<VectorValues::shared_ptr> vector_values;
const boost::shared_ptr<BayesTreeType>& continuousBayesTree, for (const DiscreteValues& assignment : assignments) {
const std::vector<DiscreteValues>& assignments) const; VectorValues values = continuousBayesNet->optimize(assignment);
vector_values.push_back(boost::make_shared<VectorValues>(values));
}
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values);
return delta_tree;
}
/** /**
* @brief Compute the unnormalized probabilities of the continuous variables * @brief Compute the unnormalized probabilities of the continuous variables
* for each of the modes. * for each of the modes.
* *
* @tparam BAYES Template on the type of Bayes graph, either a bayes net or a
* bayes tree.
* @param discrete_keys The discrete keys which form all the modes. * @param discrete_keys The discrete keys which form all the modes.
* @param continuousBayesNet The Bayes Net representing the continuous * @param continuousBayesNet The Bayes Net representing the continuous
* eliminated variables. * eliminated variables.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
template <typename BAYES>
AlgebraicDecisionTree<Key> continuousProbPrimes( AlgebraicDecisionTree<Key> continuousProbPrimes(
const DiscreteKeys& discrete_keys, const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesNetType>& continuousBayesNet) const; const boost::shared_ptr<BAYES>& continuousBayesNet) const {
AlgebraicDecisionTree<Key> continuousProbPrimes( // Generate all possible assignments.
const DiscreteKeys& discrete_keys, const std::vector<DiscreteValues> assignments =
const boost::shared_ptr<BayesTreeType>& continuousBayesTree) const; DiscreteValues::CartesianProduct(discrete_keys);
// Save a copy of the original discrete key ordering
DiscreteKeys reversed_discrete_keys(discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(reversed_discrete_keys.begin(), reversed_discrete_keys.end());
// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
this->continuousDelta(reversed_discrete_keys, continuousBayesNet,
assignments);
// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes;
for (const DiscreteValues& assignment : assignments) {
VectorValues delta = *delta_tree(assignment);
// If VectorValues is empty, it means this is a pruned branch.
// Set thr probPrime to 0.0.
if (delta.size() == 0) {
probPrimes.push_back(0.0);
continue;
}
// Compute the error given the delta and the assignment.
double error = this->error(delta, assignment);
probPrimes.push_back(exp(-error));
}
AlgebraicDecisionTree<Key> probPrimeTree(reversed_discrete_keys,
probPrimes);
return probPrimeTree;
}
std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering( std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering(
const Ordering& ordering) const; const Ordering& ordering) const;
/** /**
* @brief Custom elimination function which computes the correct * @brief Custom elimination function which computes the correct
* continuous probabilities. * continuous probabilities. Returns a bayes net.
* *
* @param continuous Optional ordering for all continuous variables. * @param continuous Optional ordering for all continuous variables.
* @param discrete Optional ordering for all discrete variables. * @param discrete Optional ordering for all discrete variables.
@ -269,27 +314,39 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
const Eliminate& function = EliminationTraitsType::DefaultEliminate, const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const; OptionalVariableIndex variableIndex = boost::none) const;
/// Sequential elimination overload for hybrid
boost::shared_ptr<BayesNetType> eliminateSequential( boost::shared_ptr<BayesNetType> eliminateSequential(
OptionalOrderingType orderingType = boost::none, OptionalOrderingType orderingType = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate, const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const; OptionalVariableIndex variableIndex = boost::none) const;
/// Sequential elimination overload for hybrid
boost::shared_ptr<BayesNetType> eliminateSequential( boost::shared_ptr<BayesNetType> eliminateSequential(
const Ordering& ordering, const Ordering& ordering,
const Eliminate& function = EliminationTraitsType::DefaultEliminate, const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const; OptionalVariableIndex variableIndex = boost::none) const;
/**
* @brief Custom elimination function which computes the correct
* continuous probabilities. Returns a bayes tree.
*
* @param continuous Optional ordering for all continuous variables.
* @param discrete Optional ordering for all discrete variables.
* @return boost::shared_ptr<BayesTreeType>
*/
boost::shared_ptr<BayesTreeType> eliminateHybridMultifrontal( boost::shared_ptr<BayesTreeType> eliminateHybridMultifrontal(
const boost::optional<Ordering> continuous = boost::none, const boost::optional<Ordering> continuous = boost::none,
const boost::optional<Ordering> discrete = boost::none, const boost::optional<Ordering> discrete = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate, const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const; OptionalVariableIndex variableIndex = boost::none) const;
/// Multifrontal elimination overload for hybrid
boost::shared_ptr<BayesTreeType> eliminateMultifrontal( boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
OptionalOrderingType orderingType = boost::none, OptionalOrderingType orderingType = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate, const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const; OptionalVariableIndex variableIndex = boost::none) const;
/// Multifrontal elimination overload for hybrid
boost::shared_ptr<BayesTreeType> eliminateMultifrontal( boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
const Ordering& ordering, const Ordering& ordering,
const Eliminate& function = EliminationTraitsType::DefaultEliminate, const Eliminate& function = EliminationTraitsType::DefaultEliminate,