add methods to perform correct elimination
parent
610a535b30
commit
1815433cbb
|
|
@ -229,6 +229,12 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
/* ************************************************************************* */
|
||||
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
||||
GaussianBayesNet gbn = this->choose(assignment);
|
||||
|
||||
// Check if there exists a nullptr in the GaussianBayesNet
|
||||
// If yes, return an empty VectorValues
|
||||
if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
|
||||
return VectorValues();
|
||||
}
|
||||
return gbn.optimize();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -492,6 +492,75 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
|||
return prob_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTree<Key, VectorValues::shared_ptr>
|
||||
HybridGaussianFactorGraph::continuousDelta(
|
||||
const DiscreteKeys &discrete_keys,
|
||||
const boost::shared_ptr<BayesNetType> &continuousBayesNet,
|
||||
const std::vector<DiscreteValues> &assignments) const {
|
||||
// Create a decision tree of all the different VectorValues
|
||||
std::vector<VectorValues::shared_ptr> vector_values;
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
VectorValues values = continuousBayesNet->optimize(assignment);
|
||||
vector_values.push_back(boost::make_shared<VectorValues>(values));
|
||||
}
|
||||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
|
||||
vector_values);
|
||||
|
||||
return delta_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
|
||||
const DiscreteKeys &discrete_keys,
|
||||
const boost::shared_ptr<BayesNetType> &continuousBayesNet,
|
||||
const std::vector<DiscreteValues> &assignments) const {
|
||||
// Create a decision tree of all the different VectorValues
|
||||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
|
||||
this->continuousDelta(discrete_keys, continuousBayesNet, assignments);
|
||||
|
||||
// Get the probPrime tree with the correct leaf probabilities
|
||||
std::vector<double> probPrimes;
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
VectorValues delta = *delta_tree(assignment);
|
||||
|
||||
// If VectorValues is empty, it means this is a pruned branch.
|
||||
// Set thr probPrime to 0.0.
|
||||
if (delta.size() == 0) {
|
||||
probPrimes.push_back(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
double error = 0.0;
|
||||
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
auto factor = factors_.at(idx);
|
||||
|
||||
if (factor->isHybrid()) {
|
||||
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
error += c->asMixture()->error(delta, assignment);
|
||||
}
|
||||
if (auto f =
|
||||
boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
||||
error += f->error(delta, assignment);
|
||||
}
|
||||
|
||||
} else if (factor->isContinuous()) {
|
||||
if (auto f =
|
||||
boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||
error += f->inner()->error(delta);
|
||||
}
|
||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
error += cg->asGaussian()->error(delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
probPrimes.push_back(exp(-error));
|
||||
}
|
||||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
||||
return probPrimeTree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
||||
HybridGaussianFactorGraph::eliminateHybridSequential() const {
|
||||
|
|
@ -502,52 +571,29 @@ HybridGaussianFactorGraph::eliminateHybridSequential() const {
|
|||
HybridBayesNet::shared_ptr bayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||
std::tie(bayesNet, discreteGraph) =
|
||||
BaseEliminateable::eliminatePartialSequential(
|
||||
continuous_ordering, EliminationTraitsType::DefaultEliminate);
|
||||
BaseEliminateable::eliminatePartialSequential(continuous_ordering);
|
||||
|
||||
// Get the last continuous conditional which will have all the discrete keys
|
||||
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
|
||||
// Get all the discrete assignments
|
||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(discrete_keys);
|
||||
|
||||
// Save a copy of the original discrete key ordering
|
||||
DiscreteKeys orig_discrete_keys(discrete_keys);
|
||||
// Reverse discrete keys order for correct tree construction
|
||||
std::reverse(discrete_keys.begin(), discrete_keys.end());
|
||||
|
||||
// Create a decision tree of all the different VectorValues
|
||||
std::vector<VectorValues::shared_ptr> vector_values;
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
VectorValues values = bayesNet->optimize(assignment);
|
||||
vector_values.push_back(boost::make_shared<VectorValues>(values));
|
||||
}
|
||||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
|
||||
vector_values);
|
||||
AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
continuousProbPrimes(discrete_keys, bayesNet, assignments);
|
||||
|
||||
// Get the probPrime tree with the correct leaf probabilities
|
||||
std::vector<double> probPrimes;
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
double error = 0.0;
|
||||
VectorValues delta = *delta_tree(assignment);
|
||||
for (auto factor : *this) {
|
||||
if (factor->isHybrid()) {
|
||||
auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
|
||||
error += f->error(delta, assignment);
|
||||
|
||||
} else if (factor->isContinuous()) {
|
||||
auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor);
|
||||
error += f->inner()->error(delta);
|
||||
}
|
||||
}
|
||||
probPrimes.push_back(exp(-error));
|
||||
}
|
||||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
discreteGraph->add(DecisionTreeFactor(orig_discrete_keys, probPrimeTree));
|
||||
|
||||
// Perform discrete elimination
|
||||
HybridBayesNet::shared_ptr discreteBayesNet =
|
||||
discreteGraph->eliminateSequential(
|
||||
discrete_ordering, EliminationTraitsType::DefaultEliminate);
|
||||
discreteGraph->eliminateSequential(discrete_ordering);
|
||||
|
||||
bayesNet->add(*discreteBayesNet);
|
||||
|
||||
return bayesNet;
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@
|
|||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
|
@ -190,6 +191,44 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
AlgebraicDecisionTree<Key> probPrime(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the VectorValues solution for the continuous variables for
|
||||
* each mode.
|
||||
*
|
||||
* @param discrete_keys The discrete keys which form all the modes.
|
||||
* @param continuousBayesNet The Bayes Net representing the continuous
|
||||
* eliminated variables.
|
||||
* @param assignments List of all discrete assignments to create the final
|
||||
* decision tree.
|
||||
* @return DecisionTree<Key, VectorValues::shared_ptr>
|
||||
*/
|
||||
DecisionTree<Key, VectorValues::shared_ptr> continuousDelta(
|
||||
const DiscreteKeys& discrete_keys,
|
||||
const boost::shared_ptr<BayesNetType>& continuousBayesNet,
|
||||
const std::vector<DiscreteValues>& assignments) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the unnormalized probabilities of the continuous variables
|
||||
* for each of the modes.
|
||||
*
|
||||
* @param discrete_keys The discrete keys which form all the modes.
|
||||
* @param continuousBayesNet The Bayes Net representing the continuous
|
||||
* eliminated variables.
|
||||
* @param assignments List of all discrete assignments to create the final
|
||||
* decision tree.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> continuousProbPrimes(
|
||||
const DiscreteKeys& discrete_keys,
|
||||
const boost::shared_ptr<BayesNetType>& continuousBayesNet,
|
||||
const std::vector<DiscreteValues>& assignments) const;
|
||||
|
||||
/**
|
||||
* @brief Custom elimination function which computes the correct
|
||||
* continuous probabilities.
|
||||
*
|
||||
* @return boost::shared_ptr<BayesNetType>
|
||||
*/
|
||||
boost::shared_ptr<BayesNetType> eliminateHybridSequential() const;
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Reference in New Issue