Use TableFactor in hybrid elimination
parent
baf25de684
commit
9c88e3ed96
|
|
@ -234,7 +234,7 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
||||||
*/
|
*/
|
||||||
std::function<GaussianConditional::shared_ptr(
|
std::function<GaussianConditional::shared_ptr(
|
||||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
|
GaussianMixture::prunerFunc(const TableFactor &discreteProbs) {
|
||||||
// Get the discrete keys as sets for the decision tree
|
// Get the discrete keys as sets for the decision tree
|
||||||
// and the gaussian mixture.
|
// and the gaussian mixture.
|
||||||
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
||||||
|
|
@ -285,9 +285,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
|
void GaussianMixture::prune(const TableFactor &discreteProbs) {
|
||||||
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
|
||||||
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
|
||||||
// Functional which loops over all assignments and create a set of
|
// Functional which loops over all assignments and create a set of
|
||||||
// GaussianConditionals
|
// GaussianConditionals
|
||||||
auto pruner = prunerFunc(discreteProbs);
|
auto pruner = prunerFunc(discreteProbs);
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
#include <gtsam/inference/Conditional.h>
|
#include <gtsam/inference/Conditional.h>
|
||||||
|
|
@ -80,7 +81,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
*/
|
*/
|
||||||
std::function<GaussianConditional::shared_ptr(
|
std::function<GaussianConditional::shared_ptr(
|
||||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
prunerFunc(const DecisionTreeFactor &discreteProbs);
|
prunerFunc(const TableFactor &discreteProbs);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
|
|
@ -238,7 +239,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
*
|
*
|
||||||
* @param discreteProbs A pruned set of probabilities for the discrete keys.
|
* @param discreteProbs A pruned set of probabilities for the discrete keys.
|
||||||
*/
|
*/
|
||||||
void prune(const DecisionTreeFactor &discreteProbs);
|
void prune(const TableFactor &discreteProbs);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
|
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||||
* @return std::function<double(const Assignment<Key> &, double)>
|
* @return std::function<double(const Assignment<Key> &, double)>
|
||||||
*/
|
*/
|
||||||
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
const DecisionTreeFactor &prunedDiscreteProbs,
|
const TableFactor &prunedDiscreteProbs,
|
||||||
const HybridConditional &conditional) {
|
const HybridConditional &conditional) {
|
||||||
// Get the discrete keys as sets for the decision tree
|
// Get the discrete keys as sets for the decision tree
|
||||||
// and the Gaussian mixture.
|
// and the Gaussian mixture.
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/global_includes.h>
|
#include <gtsam/global_includes.h>
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
|
||||||
|
|
@ -175,14 +175,15 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
|
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
|
||||||
|
|
||||||
DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
|
// TODO(Varun)
|
||||||
discreteProbs->root_ = prunedDiscreteProbs.root_;
|
// TableFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
|
||||||
|
// discreteProbs->root_ = prunedDiscreteProbs.root_;
|
||||||
|
|
||||||
/// Helper struct for pruning the hybrid bayes tree.
|
/// Helper struct for pruning the hybrid bayes tree.
|
||||||
struct HybridPrunerData {
|
struct HybridPrunerData {
|
||||||
/// The discrete decision tree after pruning.
|
/// The discrete decision tree after pruning.
|
||||||
DecisionTreeFactor prunedDiscreteProbs;
|
TableFactor prunedDiscreteProbs;
|
||||||
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
|
HybridPrunerData(const TableFactor& prunedDiscreteProbs,
|
||||||
const HybridBayesTree::sharedNode& parentClique)
|
const HybridBayesTree::sharedNode& parentClique)
|
||||||
: prunedDiscreteProbs(prunedDiscreteProbs) {}
|
: prunedDiscreteProbs(prunedDiscreteProbs) {}
|
||||||
|
|
||||||
|
|
@ -210,15 +211,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
HybridPrunerData rootData(prunedDiscreteProbs, 0);
|
// TODO(Varun)
|
||||||
{
|
// HybridPrunerData rootData(prunedDiscreteProbs, 0);
|
||||||
treeTraversal::no_op visitorPost;
|
// {
|
||||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
// treeTraversal::no_op visitorPost;
|
||||||
TbbOpenMPMixedScope threadLimiter;
|
// // Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||||
treeTraversal::DepthFirstForestParallel(
|
// TbbOpenMPMixedScope threadLimiter;
|
||||||
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
// treeTraversal::DepthFirstForestParallel(
|
||||||
visitorPost);
|
// *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
||||||
}
|
// visitorPost);
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue