Use TableFactor in hybrid elimination

release/4.3a0
Varun Agrawal 2023-07-13 16:39:55 -04:00
parent baf25de684
commit 9c88e3ed96
5 changed files with 21 additions and 21 deletions

View File

@ -234,7 +234,7 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
*/ */
std::function<GaussianConditional::shared_ptr( std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)> const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) { GaussianMixture::prunerFunc(const TableFactor &discreteProbs) {
// Get the discrete keys as sets for the decision tree // Get the discrete keys as sets for the decision tree
// and the gaussian mixture. // and the gaussian mixture.
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
@ -285,9 +285,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
} }
/* *******************************************************************************/ /* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) { void GaussianMixture::prune(const TableFactor &discreteProbs) {
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of // Functional which loops over all assignments and create a set of
// GaussianConditionals // GaussianConditionals
auto pruner = prunerFunc(discreteProbs); auto pruner = prunerFunc(discreteProbs);

View File

@ -23,6 +23,7 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
@ -80,7 +81,7 @@ class GTSAM_EXPORT GaussianMixture
*/ */
std::function<GaussianConditional::shared_ptr( std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)> const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &discreteProbs); prunerFunc(const TableFactor &discreteProbs);
public: public:
/// @name Constructors /// @name Constructors
@ -238,7 +239,7 @@ class GTSAM_EXPORT GaussianMixture
* *
* @param discreteProbs A pruned set of probabilities for the discrete keys. * @param discreteProbs A pruned set of probabilities for the discrete keys.
*/ */
void prune(const DecisionTreeFactor &discreteProbs); void prune(const TableFactor &discreteProbs);
/** /**
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while

View File

@ -64,7 +64,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
* @return std::function<double(const Assignment<Key> &, double)> * @return std::function<double(const Assignment<Key> &, double)>
*/ */
std::function<double(const Assignment<Key> &, double)> prunerFunc( std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &prunedDiscreteProbs, const TableFactor &prunedDiscreteProbs,
const HybridConditional &conditional) { const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree // Get the discrete keys as sets for the decision tree
// and the Gaussian mixture. // and the Gaussian mixture.

View File

@ -17,7 +17,6 @@
#pragma once #pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/global_includes.h> #include <gtsam/global_includes.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>

View File

@ -175,14 +175,15 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
void HybridBayesTree::prune(const size_t maxNrLeaves) { void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); // TODO(Varun)
discreteProbs->root_ = prunedDiscreteProbs.root_; // TableFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
// discreteProbs->root_ = prunedDiscreteProbs.root_;
/// Helper struct for pruning the hybrid bayes tree. /// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData { struct HybridPrunerData {
/// The discrete decision tree after pruning. /// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteProbs; TableFactor prunedDiscreteProbs;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs, HybridPrunerData(const TableFactor& prunedDiscreteProbs,
const HybridBayesTree::sharedNode& parentClique) const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteProbs(prunedDiscreteProbs) {} : prunedDiscreteProbs(prunedDiscreteProbs) {}
@ -210,15 +211,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
} }
}; };
HybridPrunerData rootData(prunedDiscreteProbs, 0); // TODO(Varun)
{ // HybridPrunerData rootData(prunedDiscreteProbs, 0);
treeTraversal::no_op visitorPost; // {
// Limits OpenMP threads since we're mixing TBB and OpenMP // treeTraversal::no_op visitorPost;
TbbOpenMPMixedScope threadLimiter; // // Limits OpenMP threads since we're mixing TBB and OpenMP
treeTraversal::DepthFirstForestParallel( // TbbOpenMPMixedScope threadLimiter;
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, // treeTraversal::DepthFirstForestParallel(
visitorPost); // *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
} // visitorPost);
// }
} }
} // namespace gtsam } // namespace gtsam