From f6d42d0ee07662bf7baf1055239c73147f4e8705 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 3 Jul 2023 16:40:27 -0400 Subject: [PATCH 01/17] small improvements --- gtsam/discrete/DecisionTreeFactor.h | 3 +++ gtsam/hybrid/HybridBayesNet.cpp | 10 +++++----- gtsam/hybrid/HybridSmoother.cpp | 3 ++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 6cce6e5d4..e90a2f96f 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -147,6 +147,9 @@ namespace gtsam { /// @name Advanced Interface /// @{ + /// Inherit all the `apply` methods from AlgebraicDecisionTree + using ADT::apply; + /** * Apply binary operator (*this) "op" f * @param f the second argument for op diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 266e02b0d..9df58814b 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -155,18 +155,18 @@ void HybridBayesNet::updateDiscreteConditionals( auto discrete = conditional->asDiscrete(); // Convert pointer from conditional to factor - auto discreteTree = - std::dynamic_pointer_cast(discrete); + auto discreteFactor = + std::dynamic_pointer_cast(discrete); // Apply prunerFunc to the underlying AlgebraicDecisionTree - DecisionTreeFactor::ADT prunedDiscreteTree = - discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional)); + DecisionTreeFactor::ADT prunedDiscreteFactor = + discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional)); gttic_(HybridBayesNet_MakeConditional); // Create the new (hybrid) conditional KeyVector frontals(discrete->frontals().begin(), discrete->frontals().end()); auto prunedDiscrete = std::make_shared( - frontals.size(), conditional->discreteKeys(), prunedDiscreteTree); + frontals.size(), conditional->discreteKeys(), prunedDiscreteFactor); conditional = std::make_shared(prunedDiscrete); gttoc_(HybridBayesNet_MakeConditional); diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 56c62cf19..27d3f70fc 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -72,7 +72,8 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph, addConditionals(graph, hybridBayesNet_, ordering); // Eliminate. - auto bayesNetFragment = graph.eliminateSequential(ordering); + HybridBayesNet::shared_ptr bayesNetFragment = + graph.eliminateSequential(ordering); /// Prune if (maxNrLeaves) { From 53d00864bbf41d352c8f0bee765691fbcf7002c5 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 5 Jul 2023 11:04:56 -0400 Subject: [PATCH 02/17] small cleanup --- gtsam/hybrid/HybridBayesNet.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 9df58814b..0db1b0c48 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -146,8 +146,7 @@ std::function &, double)> prunerFunc( /* ************************************************************************* */ void HybridBayesNet::updateDiscreteConditionals( const DecisionTreeFactor &prunedDiscreteProbs) { - KeyVector prunedTreeKeys = prunedDiscreteProbs.keys(); - + //TODO(Varun) Should prune the joint conditional, maybe during elimination? // Loop with index since we need it later. for (size_t i = 0; i < this->size(); i++) { HybridConditional::shared_ptr conditional = this->at(i); From 6d69ca16dacc6269b55a85b55b0d8def40e62195 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 5 Jul 2023 11:09:14 -0400 Subject: [PATCH 03/17] add separate Hybrid ISAM and Smoother tests --- gtsam/hybrid/tests/testHybridEstimation.cpp | 55 +++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index b5f5244fa..b8edc39d8 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -140,6 +140,61 @@ TEST(HybridEstimation, IncrementalSmoother) { EXPECT(assert_equal(expected_continuous, result)); } +/****************************************************************************/ +// Test approximate inference with an additional pruning step. +TEST(HybridEstimation, ISAM) { + size_t K = 15; + std::vector measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, + 7, 8, 9, 9, 9, 10, 11, 11, 11, 11}; + // Ground truth discrete seq + std::vector discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 1, 1, 0, 0, 0}; + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1"); + HybridNonlinearISAM isam; + HybridNonlinearFactorGraph graph; + Values initial; + + // gttic_(Estimation); + + // Add the X(0) prior + graph.push_back(switching.nonlinearFactorGraph.at(0)); + initial.insert(X(0), switching.linearizationPoint.at(X(0))); + + HybridGaussianFactorGraph linearized; + + for (size_t k = 1; k < K; k++) { + // Motion Model + graph.push_back(switching.nonlinearFactorGraph.at(k)); + // Measurement + graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1)); + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + isam.update(graph, initial, 3); + // isam.bayesTree().print("\n\n"); + + graph.resize(0); + initial.clear(); + } + + Values result = isam.estimate(); + DiscreteValues assignment = isam.assignment(); + + DiscreteValues expected_discrete; + for (size_t k = 0; k < K - 1; k++) { + expected_discrete[M(k)] = discrete_seq[k]; + } + EXPECT(assert_equal(expected_discrete, assignment)); + + Values expected_continuous; + for (size_t k = 0; k < K; k++) { + expected_continuous.insert(X(k), measurements[k]); + } + EXPECT(assert_equal(expected_continuous, result)); +} + /** * @brief A function to get a specific 1D robot motion problem as a linearized * factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous From f751a5bfcf085f429cf48ce51f854587d1b28cad Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 6 Jul 2023 16:37:22 -0400 Subject: [PATCH 04/17] overload apply method in DecisionTreeFactor --- gtsam/discrete/DecisionTreeFactor.cpp | 8 ++++++++ gtsam/discrete/DecisionTreeFactor.h | 7 +++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ff18268b1..56f1659dc 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -101,6 +101,14 @@ namespace gtsam { return DecisionTreeFactor(keys, result); } + /* ************************************************************************ */ + DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const { + // apply operand + ADT result = ADT::apply(op); + // Make a new factor + return DecisionTreeFactor(discreteKeys(), result); + } + /* ************************************************************************ */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( size_t nrFrontals, ADT::Binary op) const { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index e90a2f96f..e92c82b77 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -147,8 +147,11 @@ namespace gtsam { /// @name Advanced Interface /// @{ - /// Inherit all the `apply` methods from AlgebraicDecisionTree - using ADT::apply; + /** + * Apply unary operator (*this) "op" f + * @param op a unary operator that operates on AlgebraicDecisionTree + */ + DecisionTreeFactor apply(ADT::UnaryAssignment op) const; /** * Apply binary operator (*this) "op" f From 4e902fc8a71f2124b066d823474eeab63fec019c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 6 Jul 2023 22:22:51 -0400 Subject: [PATCH 05/17] fix continuousKeySet method --- gtsam/hybrid/HybridFactorGraph.cpp | 2 ++ gtsam/hybrid/tests/testHybridFactorGraph.cpp | 28 +++++++++++++++++++ .../tests/testHybridGaussianFactorGraph.cpp | 2 +- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index d96a890f4..235ffc87f 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -67,6 +67,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const { for (const Key& key : p->continuousKeys()) { keys.insert(key); } + } else if (auto p = std::dynamic_pointer_cast(factor)) { + keys.insert(p->keys().begin(), p->keys().end()); } } return keys; diff --git a/gtsam/hybrid/tests/testHybridFactorGraph.cpp b/gtsam/hybrid/tests/testHybridFactorGraph.cpp index f5b4ec0b1..33c0761eb 100644 --- a/gtsam/hybrid/tests/testHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridFactorGraph.cpp @@ -18,7 +18,9 @@ #include #include #include +#include #include +#include #include using namespace std; @@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) { HybridFactorGraph fg; } +/* ************************************************************************* */ +// Test if methods to get keys work as expected. +TEST(HybridFactorGraph, Keys) { + HybridGaussianFactorGraph hfg; + + // Add prior on x0 + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + + // Add factor between x0 and x1 + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + // Add a gaussian mixture factor ϕ(x1, c1) + DiscreteKey m1(M(1), 2); + DecisionTree dt( + M(1), std::make_shared(X(1), I_3x3, Z_3x1), + std::make_shared(X(1), I_3x3, Vector3::Ones())); + hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt)); + + KeySet expected_continuous{X(0), X(1)}; + EXPECT( + assert_container_equality(expected_continuous, hfg.continuousKeySet())); + + KeySet expected_discrete{M(1)}; + EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet())); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 8276264ae..1da897103 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -902,7 +902,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) { // Test resulting posterior Bayes net has correct size: EXPECT_LONGS_EQUAL(8, posterior->size()); - // TODO(dellaert): this test fails - no idea why. + // Ratio test EXPECT(ratioTest(bn, measurements, *posterior)); } From 4e13fb717bcc1a49358f25acd45a987246d011c5 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 6 Jul 2023 23:07:38 -0400 Subject: [PATCH 06/17] simplify HybridEliminate --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 86 +++++++++++----------- 1 file changed, 45 insertions(+), 41 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 2b23ed4db..d2ea3d5ef 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -348,64 +348,68 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // When the number of assignments is large we may encounter stack overflows. // However this is also the case with iSAM2, so no pressure :) - // PREPROCESS: Identify the nature of the current elimination - - // TODO(dellaert): just check the factors: + // Check the factors: // 1. if all factors are discrete, then we can do discrete elimination: // 2. if all factors are continuous, then we can do continuous elimination: // 3. if not, we do hybrid elimination: - // First, identify the separator keys, i.e. all keys that are not frontal. - KeySet separatorKeys; + bool only_discrete = true, only_continuous = true; for (auto &&factor : factors) { - separatorKeys.insert(factor->begin(), factor->end()); - } - // remove frontals from separator - for (auto &k : frontalKeys) { - separatorKeys.erase(k); - } - - // Build a map from keys to DiscreteKeys - auto mapFromKeyToDiscreteKey = factors.discreteKeyMap(); - - // Fill in discrete frontals and continuous frontals. - std::set discreteFrontals; - KeySet continuousFrontals; - for (auto &k : frontalKeys) { - if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { - discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k)); - } else { - continuousFrontals.insert(k); + if (auto hybrid_factor = std::dynamic_pointer_cast(factor)) { + if (hybrid_factor->isDiscrete()) { + only_continuous = false; + } else if (hybrid_factor->isContinuous()) { + only_discrete = false; + } else if (hybrid_factor->isHybrid()) { + only_continuous = false; + only_discrete = false; + } + } else if (auto cont_factor = + std::dynamic_pointer_cast(factor)) { + only_discrete = false; + } else if (auto discrete_factor = + std::dynamic_pointer_cast(factor)) { + only_continuous = false; } } - // Fill in discrete discrete separator keys and continuous separator keys. - std::set discreteSeparatorSet; - KeyVector continuousSeparator; - for (auto &k : separatorKeys) { - if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { - discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); - } else { - continuousSeparator.push_back(k); - } - } - - // Check if we have any continuous keys: - const bool discrete_only = - continuousFrontals.empty() && continuousSeparator.empty(); - // NOTE: We should really defer the product here because of pruning - if (discrete_only) { + if (only_discrete) { // Case 1: we are only dealing with discrete return discreteElimination(factors, frontalKeys); - } else if (mapFromKeyToDiscreteKey.empty()) { + } else if (only_continuous) { // Case 2: we are only dealing with continuous return continuousElimination(factors, frontalKeys); } else { // Case 3: We are now in the hybrid land! + KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end()); + + // Find all the keys in the set of continuous keys + // which are not in the frontal keys. This is our continuous separator. + KeyVector continuousSeparator; + auto continuousKeySet = factors.continuousKeySet(); + std::set_difference( + continuousKeySet.begin(), continuousKeySet.end(), + frontalKeysSet.begin(), frontalKeysSet.end(), + std::inserter(continuousSeparator, continuousSeparator.begin())); + + // Similarly for the discrete separator. + KeySet discreteSeparatorSet; + std::set discreteSeparator; + auto discreteKeySet = factors.discreteKeySet(); + std::set_difference( + discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(), + frontalKeysSet.end(), + std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin())); + // Convert from set of keys to set of DiscreteKeys + auto discreteKeyMap = factors.discreteKeyMap(); + for (auto key : discreteSeparatorSet) { + discreteSeparator.insert(discreteKeyMap.at(key)); + } + return hybridElimination(factors, frontalKeys, continuousSeparator, - discreteSeparatorSet); + discreteSeparator); } } From f6b1872b1335ab22045aef998832b3aff8020558 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 8 Jul 2023 13:09:35 -0400 Subject: [PATCH 07/17] initial changes --- gtsam/hybrid/HybridFactorGraph.cpp | 3 +-- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 16 +++++++++------- gtsam/hybrid/HybridGaussianFactorGraph.h | 1 + gtsam/hybrid/HybridJunctionTree.cpp | 4 ++-- gtsam/hybrid/HybridNonlinearFactorGraph.cpp | 3 ++- gtsam/hybrid/tests/testHybridBayesTree.cpp | 2 +- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index 235ffc87f..f7b96f694 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -17,7 +17,6 @@ * @date January, 2023 */ -#include #include namespace gtsam { @@ -26,7 +25,7 @@ namespace gtsam { std::set HybridFactorGraph::discreteKeys() const { std::set keys; for (auto& factor : factors_) { - if (auto p = std::dynamic_pointer_cast(factor)) { + if (auto p = std::dynamic_pointer_cast(factor)) { for (const DiscreteKey& key : p->discreteKeys()) { keys.insert(key); } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index d2ea3d5ef..fb4b69aaf 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -48,8 +48,6 @@ #include #include -// #define HYBRID_TIMING - namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: @@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { // TODO(dellaert): in C++20, we can use std::visit. continue; } - } else if (dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // Don't do anything for discrete-only factors // since we want to eliminate continuous values only. continue; @@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors, DiscreteFactorGraph dfg; for (auto &f : factors) { - if (auto dtf = dynamic_pointer_cast(f)) { - dfg.push_back(dtf); + if (auto df = dynamic_pointer_cast(f)) { + dfg.push_back(df); } else if (auto orphan = dynamic_pointer_cast(f)) { // Ignore orphaned clique. // TODO(dellaert): is this correct? If so explain here. @@ -262,9 +260,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors, }; DecisionTree probabilities(eliminationResults, probability); + + auto dtf = + std::make_shared(discreteSeparator, probabilities); + return { std::make_shared(gaussianMixture), - std::make_shared(discreteSeparator, probabilities)}; + std::make_shared(discreteSeparator, dtf->probabilities())}; } else { // Otherwise, we create a resulting GaussianMixtureFactor on the separator, // taking care to correct for conditional constant. @@ -433,7 +435,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( // Add the gaussian factor error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); - } else if (dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // If factor at `idx` is discrete-only, we skip. continue; } else { diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 421e69aa0..b3f159150 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -40,6 +40,7 @@ class HybridEliminationTree; class HybridBayesTree; class HybridJunctionTree; class DecisionTreeFactor; +class TableFactor; class JacobianFactor; class HybridValues; diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 6f2898bf1..22d3c7dd2 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -66,7 +66,7 @@ struct HybridConstructorTraversalData { for (auto& k : hf->discreteKeys()) { data.discreteKeys.insert(k.first); } - } else if (auto hf = std::dynamic_pointer_cast(f)) { + } else if (auto hf = std::dynamic_pointer_cast(f)) { for (auto& k : hf->discreteKeys()) { data.discreteKeys.insert(k.first); } @@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree( Data rootData(0); rootData.junctionTreeNode = std::make_shared(); // Make a dummy node to gather - // the junction tree roots + // the junction tree roots treeTraversal::DepthFirstForest(eliminationTree, rootData, Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPost); diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 260f534e3..2459e4ec9 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -17,6 +17,7 @@ */ #include +#include #include #include #include @@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( } else if (auto nlf = dynamic_pointer_cast(f)) { const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues); linearFG->push_back(gf); - } else if (dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // If discrete-only: doesn't need linearization. linearFG->push_back(f); } else if (auto gmf = dynamic_pointer_cast(f)) { diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index 578f5d605..81b257c32 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) { DiscreteFactorGraph dfg; for (auto&& f : *remainingFactorGraph) { - auto discreteFactor = dynamic_pointer_cast(f); + auto discreteFactor = dynamic_pointer_cast(f); assert(discreteFactor); dfg.push_back(discreteFactor); } From 2940e69a73479b88f20b7c46474e4eae5e0d9845 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 9 Jul 2023 20:24:24 -0400 Subject: [PATCH 08/17] discreteConditionals returns DiscreteConditional --- gtsam/hybrid/HybridBayesNet.cpp | 18 +++++++----------- gtsam/hybrid/HybridBayesNet.h | 4 ++-- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 0db1b0c48..3b5ab5b80 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -38,21 +38,17 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { } /* ************************************************************************* */ -DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { - AlgebraicDecisionTree discreteProbs; - +DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const { // The canonical decision tree factor which will get // the discrete conditionals added to it. - DecisionTreeFactor discreteProbsFactor; + DiscreteConditional discreteProbs; for (auto &&conditional : *this) { if (conditional->isDiscrete()) { - // Convert to a DecisionTreeFactor and add it to the main factor. - DecisionTreeFactor f(*conditional->asDiscrete()); - discreteProbsFactor = discreteProbsFactor * f; + discreteProbs = discreteProbs * (*conditional->asDiscrete()); } } - return std::make_shared(discreteProbsFactor); + return std::make_shared(discreteProbs); } /* ************************************************************************* */ @@ -146,8 +142,8 @@ std::function &, double)> prunerFunc( /* ************************************************************************* */ void HybridBayesNet::updateDiscreteConditionals( const DecisionTreeFactor &prunedDiscreteProbs) { - //TODO(Varun) Should prune the joint conditional, maybe during elimination? - // Loop with index since we need it later. + // TODO(Varun) Should prune the joint conditional, maybe during elimination? + // Loop with index since we need it later. for (size_t i = 0; i < this->size(); i++) { HybridConditional::shared_ptr conditional = this->at(i); if (conditional->isDiscrete()) { @@ -179,7 +175,7 @@ void HybridBayesNet::updateDiscreteConditionals( HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { // Get the decision tree of only the discrete keys gttic_(HybridBayesNet_PruneDiscreteConditionals); - DecisionTreeFactor::shared_ptr discreteConditionals = + DiscreteConditional::shared_ptr discreteConditionals = this->discreteConditionals(); const DecisionTreeFactor prunedDiscreteProbs = discreteConditionals->prune(maxNrLeaves); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 23fc4d5d3..19e88d754 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -139,9 +139,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /** * @brief Get all the discrete conditionals as a decision tree factor. * - * @return DecisionTreeFactor::shared_ptr + * @return DiscreteConditional::shared_ptr */ - DecisionTreeFactor::shared_ptr discreteConditionals() const; + DiscreteConditional::shared_ptr discreteConditionals() const; /** * @brief Sample from an incomplete BayesNet, given missing variables. From 2f4133fd49ee3367c7ff7684187f84f4eeb49dc2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 10 Jul 2023 19:39:28 -0400 Subject: [PATCH 09/17] Revert "remove nrAssignments from the DecisionTree" This reverts commit 647d3c0744171ffa78f585a39a88bfa4be4d2002. --- gtsam/discrete/DecisionTree-inl.h | 13 +++++++- gtsam/discrete/DecisionTree.h | 36 +++++++++++++++++++++++ gtsam/discrete/tests/testDecisionTree.cpp | 32 ++++++++++++++++++++ 3 files changed, 80 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 9d618dea0..19e3e2887 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -93,7 +93,8 @@ namespace gtsam { /// print void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { - std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; + std::cout << s << " Leaf [" << nrAssignments() << "]" + << valueFormatter(constant_) << std::endl; } /** Write graphviz format to stream `os`. */ @@ -827,6 +828,16 @@ namespace gtsam { return total; } + /****************************************************************************/ + template + size_t DecisionTree::nrAssignments() const { + size_t n = 0; + this->visitLeaf([&n](const DecisionTree::Leaf& leaf) { + n += leaf.nrAssignments(); + }); + return n; + } + /****************************************************************************/ // fold is just done with a visit template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index ed1908485..bee0ce5c7 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -299,6 +299,42 @@ namespace gtsam { /// Return the number of leaves in the tree. size_t nrLeaves() const; + /** + * @brief This is a convenience function which returns the total number of + * leaf assignments in the decision tree. + * This function is not used for anymajor operations within the discrete + * factor graph framework. + * + * Leaf assignments represent the cardinality of each leaf node, e.g. in a + * binary tree each leaf has 2 assignments. This includes counts removed + * from implicit pruning hence, it will always be >= nrLeaves(). + * + * E.g. we have a decision tree as below, where each node has 2 branches: + * + * Choice(m1) + * 0 Choice(m0) + * 0 0 Leaf 0.0 + * 0 1 Leaf 0.0 + * 1 Choice(m0) + * 1 0 Leaf 1.0 + * 1 1 Leaf 2.0 + * + * In the unpruned form, the tree will have 4 assignments, 2 for each key, + * and 4 leaves. + * + * In the pruned form, the number of assignments is still 4 but the number + * of leaves is now 3, as below: + * + * Choice(m1) + * 0 Leaf 0.0 + * 1 Choice(m0) + * 1 0 Leaf 1.0 + * 1 1 Leaf 2.0 + * + * @return size_t + */ + size_t nrAssignments() const; + /** * @brief Fold a binary function over the tree, returning accumulator. * diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index d2a94ddc3..395915068 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -531,6 +531,38 @@ TEST(DecisionTree, ApplyWithAssignment) { EXPECT_LONGS_EQUAL(5, count); } +/* ************************************************************************** */ +// Test number of assignments. +TEST(DecisionTree, NrAssignments2) { + using gtsam::symbol_shorthand::M; + + std::vector probs = {0, 0, 1, 2}; + + /* Create the decision tree + Choice(m1) + 0 Leaf 0.000000 + 1 Choice(m0) + 1 0 Leaf 1.000000 + 1 1 Leaf 2.000000 + */ + DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; + DecisionTree dt1(keys, probs); + EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); + + /* Create the DecisionTree + Choice(m1) + 0 Choice(m0) + 0 0 Leaf 0.000000 + 0 1 Leaf 1.000000 + 1 Choice(m0) + 1 0 Leaf 0.000000 + 1 1 Leaf 2.000000 + */ + DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; + DecisionTree dt2(keys2, probs); + EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); +} + /* ************************************************************************* */ int main() { TestResult tr; From ddb36c2e7b20606bd31baeac0b190083f5d9c780 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 10 Jul 2023 19:39:36 -0400 Subject: [PATCH 10/17] Revert "enumerate all assignments for computing probabilities to prune" This reverts commit 8c38e45c83c6bbeb2ef5e7c7d73807b2eed6831f. --- .../tests/testGaussianMixtureFactor.cpp | 4 +- .../tests/testHybridNonlinearFactorGraph.cpp | 38 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 75ba5a059..5207e9372 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) { std::string expected = R"(Hybrid [x1 x2; 1]{ Choice(1) - 0 Leaf : + 0 Leaf [1]: A[x1] = [ 0; 0 @@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) { b = [ 0 0 ] No noise model - 1 Leaf : + 1 Leaf [1]: A[x1] = [ 0; 0 diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index af3a23b94..7bcaf1762 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -492,7 +492,7 @@ factor 0: factor 1: Hybrid [x0 x1; m0]{ Choice(m0) - 0 Leaf : + 0 Leaf [1]: A[x0] = [ -1 ] @@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{ b = [ -1 ] No noise model - 1 Leaf : + 1 Leaf [1]: A[x0] = [ -1 ] @@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{ factor 2: Hybrid [x1 x2; m1]{ Choice(m1) - 0 Leaf : + 0 Leaf [1]: A[x1] = [ -1 ] @@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{ b = [ -1 ] No noise model - 1 Leaf : + 1 Leaf [1]: A[x1] = [ -1 ] @@ -550,16 +550,16 @@ factor 4: b = [ -10 ] No noise model factor 5: P( m0 ): - Leaf 0.5 + Leaf [2] 0.5 factor 6: P( m1 | m0 ): Choice(m1) 0 Choice(m0) - 0 0 Leaf 0.33333333 - 0 1 Leaf 0.6 + 0 0 Leaf [1]0.33333333 + 0 1 Leaf [1] 0.6 1 Choice(m0) - 1 0 Leaf 0.66666667 - 1 1 Leaf 0.4 + 1 0 Leaf [1]0.66666667 + 1 1 Leaf [1] 0.4 )"; EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph)); @@ -570,13 +570,13 @@ size: 3 conditional 0: Hybrid P( x0 | x1 m0) Discrete Keys = (m0, 2), Choice(m0) - 0 Leaf p(x0 | x1) + 0 Leaf [1] p(x0 | x1) R = [ 10.0499 ] S[x1] = [ -0.0995037 ] d = [ -9.85087 ] No noise model - 1 Leaf p(x0 | x1) + 1 Leaf [1] p(x0 | x1) R = [ 10.0499 ] S[x1] = [ -0.0995037 ] d = [ -9.95037 ] @@ -586,26 +586,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1) Discrete Keys = (m0, 2), (m1, 2), Choice(m1) 0 Choice(m0) - 0 0 Leaf p(x1 | x2) + 0 0 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -9.99901 ] No noise model - 0 1 Leaf p(x1 | x2) + 0 1 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -9.90098 ] No noise model 1 Choice(m0) - 1 0 Leaf p(x1 | x2) + 1 0 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -10.098 ] No noise model - 1 1 Leaf p(x1 | x2) + 1 1 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -10 ] @@ -615,14 +615,14 @@ conditional 2: Hybrid P( x2 | m0 m1) Discrete Keys = (m0, 2), (m1, 2), Choice(m1) 0 Choice(m0) - 0 0 Leaf p(x2) + 0 0 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.1489 ] mean: 1 elements x2: -1.0099 No noise model - 0 1 Leaf p(x2) + 0 1 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.1479 ] mean: 1 elements @@ -630,14 +630,14 @@ conditional 2: Hybrid P( x2 | m0 m1) No noise model 1 Choice(m0) - 1 0 Leaf p(x2) + 1 0 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.0504 ] mean: 1 elements x2: -1.0001 No noise model - 1 1 Leaf p(x2) + 1 1 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.0494 ] mean: 1 elements From f4adfac4fa4fff940165fa99283c31b7fd18123f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 12 Jul 2023 22:54:39 -0400 Subject: [PATCH 11/17] Undo TableFactor return in Hybrid GFG so we can group the changes together --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index fb4b69aaf..2d4ac83f6 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -261,12 +261,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors, DecisionTree probabilities(eliminationResults, probability); - auto dtf = - std::make_shared(discreteSeparator, probabilities); - return { std::make_shared(gaussianMixture), - std::make_shared(discreteSeparator, dtf->probabilities())}; + std::make_shared(discreteSeparator, probabilities)}; } else { // Otherwise, we create a resulting GaussianMixtureFactor on the separator, // taking care to correct for conditional constant. From 6a26ecf971a30d822b586768a508338068c4a56f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 13 Jul 2023 15:31:06 -0400 Subject: [PATCH 12/17] templetize functions in Switching fixture --- gtsam/hybrid/tests/Switching.h | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index 5842e1f1a..4b2d3f11b 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -202,31 +202,16 @@ struct Switching { * @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2). * E.g. if K=4, we want M0, M1 and M2. * - * @param fg The nonlinear factor graph to which the mode chain is added. + * @param fg The factor graph to which the mode chain is added. */ - void addModeChain(HybridNonlinearFactorGraph *fg, + template + void addModeChain(FACTORGRAPH *fg, std::string discrete_transition_prob = "1/2 3/2") { - fg->emplace_shared(modes[0], "1/1"); + fg->template emplace_shared(modes[0], "1/1"); for (size_t k = 0; k < K - 2; k++) { auto parents = {modes[k]}; - fg->emplace_shared(modes[k + 1], parents, - discrete_transition_prob); - } - } - - /** - * @brief Add "mode chain" to HybridGaussianFactorGraph from M(0) to M(K-2). - * E.g. if K=4, we want M0, M1 and M2. - * - * @param fg The gaussian factor graph to which the mode chain is added. - */ - void addModeChain(HybridGaussianFactorGraph *fg, - std::string discrete_transition_prob = "1/2 3/2") { - fg->emplace_shared(modes[0], "1/1"); - for (size_t k = 0; k < K - 2; k++) { - auto parents = {modes[k]}; - fg->emplace_shared(modes[k + 1], parents, - discrete_transition_prob); + fg->template emplace_shared( + modes[k + 1], parents, discrete_transition_prob); } } }; From f7071298c3b63bda597b7e6c9e33740ea62f1381 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 13 Jul 2023 16:06:16 -0400 Subject: [PATCH 13/17] small improvements to comments and code structure --- gtsam/hybrid/HybridBayesNet.cpp | 7 +++---- gtsam/hybrid/HybridSmoother.cpp | 3 ++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 3b5ab5b80..ff2752bcb 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -39,8 +39,7 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { /* ************************************************************************* */ DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const { - // The canonical decision tree factor which will get - // the discrete conditionals added to it. + // The joint discrete probability. DiscreteConditional discreteProbs; for (auto &&conditional : *this) { @@ -152,7 +151,7 @@ void HybridBayesNet::updateDiscreteConditionals( // Convert pointer from conditional to factor auto discreteFactor = std::dynamic_pointer_cast(discrete); - // Apply prunerFunc to the underlying AlgebraicDecisionTree + // Apply prunerFunc to the underlying conditional DecisionTreeFactor::ADT prunedDiscreteFactor = discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional)); @@ -173,7 +172,7 @@ void HybridBayesNet::updateDiscreteConditionals( /* ************************************************************************* */ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { - // Get the decision tree of only the discrete keys + // Get the joint distribution of only the discrete keys gttic_(HybridBayesNet_PruneDiscreteConditionals); DiscreteConditional::shared_ptr discreteConditionals = this->discreteConditionals(); diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 27d3f70fc..afa8340d2 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -97,7 +97,8 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, HybridGaussianFactorGraph graph(originalGraph); HybridBayesNet hybridBayesNet(originalHybridBayesNet); - // If we are not at the first iteration, means we have conditionals to add. + // If hybridBayesNet is not empty, + // it means we have conditionals to add to the factor graph. if (!hybridBayesNet.empty()) { // We add all relevant conditional mixtures on the last continuous variable // in the previous `hybridBayesNet` to the graph From 3fe2682d93d30e6ad011dd099cf428f680fe7bd2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 14 Jul 2023 17:52:22 -0400 Subject: [PATCH 14/17] prune joint discrete probability which is faster --- gtsam/hybrid/HybridBayesNet.cpp | 85 ++++++++++------------- gtsam/hybrid/HybridBayesNet.h | 13 +--- gtsam/hybrid/tests/testHybridBayesNet.cpp | 26 +++++-- 3 files changed, 58 insertions(+), 66 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index ff2752bcb..b4bf61220 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -37,19 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { return Base::equals(bn, tol); } -/* ************************************************************************* */ -DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const { - // The joint discrete probability. - DiscreteConditional discreteProbs; - - for (auto &&conditional : *this) { - if (conditional->isDiscrete()) { - discreteProbs = discreteProbs * (*conditional->asDiscrete()); - } - } - return std::make_shared(discreteProbs); -} - /* ************************************************************************* */ /** * @brief Helper function to get the pruner functional. @@ -139,52 +126,52 @@ std::function &, double)> prunerFunc( } /* ************************************************************************* */ -void HybridBayesNet::updateDiscreteConditionals( - const DecisionTreeFactor &prunedDiscreteProbs) { - // TODO(Varun) Should prune the joint conditional, maybe during elimination? - // Loop with index since we need it later. +DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( + size_t maxNrLeaves) { + // Get the joint distribution of only the discrete keys + gttic_(HybridBayesNet_PruneDiscreteConditionals); + // The joint discrete probability. + DiscreteConditional discreteProbs; + + std::vector discrete_factor_idxs; + // Record frontal keys so we can maintain ordering + Ordering discrete_frontals; + for (size_t i = 0; i < this->size(); i++) { - HybridConditional::shared_ptr conditional = this->at(i); + auto conditional = this->at(i); if (conditional->isDiscrete()) { - auto discrete = conditional->asDiscrete(); + discreteProbs = discreteProbs * (*conditional->asDiscrete()); - // Convert pointer from conditional to factor - auto discreteFactor = - std::dynamic_pointer_cast(discrete); - // Apply prunerFunc to the underlying conditional - DecisionTreeFactor::ADT prunedDiscreteFactor = - discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional)); - - gttic_(HybridBayesNet_MakeConditional); - // Create the new (hybrid) conditional - KeyVector frontals(discrete->frontals().begin(), - discrete->frontals().end()); - auto prunedDiscrete = std::make_shared( - frontals.size(), conditional->discreteKeys(), prunedDiscreteFactor); - conditional = std::make_shared(prunedDiscrete); - gttoc_(HybridBayesNet_MakeConditional); - - // Add it back to the BayesNet - this->at(i) = conditional; + Ordering conditional_keys(conditional->frontals()); + discrete_frontals += conditional_keys; + discrete_factor_idxs.push_back(i); } } + const DecisionTreeFactor prunedDiscreteProbs = + discreteProbs.prune(maxNrLeaves); + gttoc_(HybridBayesNet_PruneDiscreteConditionals); + + // Eliminate joint probability back into conditionals + gttic_(HybridBayesNet_UpdateDiscreteConditionals); + DiscreteFactorGraph dfg{prunedDiscreteProbs}; + DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals); + + // Assign pruned discrete conditionals back at the correct indices. + for (size_t i = 0; i < discrete_factor_idxs.size(); i++) { + size_t idx = discrete_factor_idxs.at(i); + this->at(idx) = std::make_shared(dbn->at(i)); + } + gttoc_(HybridBayesNet_UpdateDiscreteConditionals); + + return prunedDiscreteProbs; } /* ************************************************************************* */ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { - // Get the joint distribution of only the discrete keys - gttic_(HybridBayesNet_PruneDiscreteConditionals); - DiscreteConditional::shared_ptr discreteConditionals = - this->discreteConditionals(); - const DecisionTreeFactor prunedDiscreteProbs = - discreteConditionals->prune(maxNrLeaves); - gttoc_(HybridBayesNet_PruneDiscreteConditionals); + DecisionTreeFactor prunedDiscreteProbs = + this->pruneDiscreteConditionals(maxNrLeaves); - gttic_(HybridBayesNet_UpdateDiscreteConditionals); - this->updateDiscreteConditionals(prunedDiscreteProbs); - gttoc_(HybridBayesNet_UpdateDiscreteConditionals); - - /* To Prune, we visitWith every leaf in the GaussianMixture. + /* To prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. * diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 19e88d754..e71cfe9b4 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ VectorValues optimize(const DiscreteValues &assignment) const; - /** - * @brief Get all the discrete conditionals as a decision tree factor. - * - * @return DiscreteConditional::shared_ptr - */ - DiscreteConditional::shared_ptr discreteConditionals() const; - /** * @brief Sample from an incomplete BayesNet, given missing variables. * @@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { private: /** - * @brief Update the discrete conditionals with the pruned versions. + * @brief Prune all the discrete conditionals. * - * @param prunedDiscreteProbs + * @param maxNrLeaves */ - void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs); + DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves); #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index f25675a55..1dfcbd6b7 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) { auto prunedTree = prunedBayesNet.evaluate(delta.continuous()); // Regression test on pruned logProbability tree - std::vector pruned_leaves = {0.0, 20.346113, 0.0, 19.738098}; + std::vector pruned_leaves = {0.0, 32.713418, 0.0, 31.735823}; AlgebraicDecisionTree expected_pruned(discrete_keys, pruned_leaves); EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); @@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) { logProbability += posterior->at(4)->asDiscrete()->logProbability(hybridValues); + // Regression double density = exp(logProbability); - EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9); + EXPECT_DOUBLES_EQUAL(density, + 1.6078460548731697 * actualTree(discrete_values), 1e-6); EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues), 1e-9); @@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { EXPECT_LONGS_EQUAL(7, posterior->size()); size_t maxNrLeaves = 3; - auto discreteConditionals = posterior->discreteConditionals(); + DiscreteConditional discreteConditionals; + for (auto&& conditional : *posterior) { + if (conditional->isDiscrete()) { + discreteConditionals = + discreteConditionals * (*conditional->asDiscrete()); + } + } const DecisionTreeFactor::shared_ptr prunedDecisionTree = std::make_shared( - discreteConditionals->prune(maxNrLeaves)); + discreteConditionals.prune(maxNrLeaves)); EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, prunedDecisionTree->nrLeaves()); - auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete()); + // regression + DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}}; + DecisionTreeFactor::ADT potentials( + dkeys, std::vector{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577}); + DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials); // Prune! posterior->prune(maxNrLeaves); - // Functor to verify values against the original_discrete_conditionals + // Functor to verify values against the expected_discrete_conditionals auto checker = [&](const Assignment& assignment, double probability) -> double { // typecast so we can use this to get probability value @@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { if (prunedDecisionTree->operator()(choices) == 0) { EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9); } else { - EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability, + EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability, 1e-9); } return 0.0; From 2f3fcff9160a623eb6021c013240a5965e6876ec Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 17 Jul 2023 12:47:19 -0400 Subject: [PATCH 15/17] fix tests --- gtsam/discrete/tests/testDecisionTree.cpp | 33 +---------------------- gtsam/hybrid/tests/testMixtureFactor.cpp | 4 +-- 2 files changed, 3 insertions(+), 34 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 395915068..efa7a1c44 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -531,38 +532,6 @@ TEST(DecisionTree, ApplyWithAssignment) { EXPECT_LONGS_EQUAL(5, count); } -/* ************************************************************************** */ -// Test number of assignments. -TEST(DecisionTree, NrAssignments2) { - using gtsam::symbol_shorthand::M; - - std::vector probs = {0, 0, 1, 2}; - - /* Create the decision tree - Choice(m1) - 0 Leaf 0.000000 - 1 Choice(m0) - 1 0 Leaf 1.000000 - 1 1 Leaf 2.000000 - */ - DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; - DecisionTree dt1(keys, probs); - EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); - - /* Create the DecisionTree - Choice(m1) - 0 Choice(m0) - 0 0 Leaf 0.000000 - 0 1 Leaf 1.000000 - 1 Choice(m0) - 1 0 Leaf 0.000000 - 1 1 Leaf 2.000000 - */ - DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; - DecisionTree dt2(keys2, probs); - EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); -} - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 67a7fd8ae..79188b909 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) { R"(Hybrid [x1 x2; 1] MixtureFactor Choice(1) - 0 Leaf Nonlinear factor on 2 keys - 1 Leaf Nonlinear factor on 2 keys + 0 Leaf [1]Nonlinear factor on 2 keys + 1 Leaf [1]Nonlinear factor on 2 keys )"; EXPECT(assert_print_equal(expected, mixtureFactor)); } From 103641c51ad89738fd90c414fede5de957c1e7c2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 17 Jul 2023 13:31:19 -0400 Subject: [PATCH 16/17] include fix --- gtsam/hybrid/HybridFactor.h | 1 + 1 file changed, 1 insertion(+) diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 13d5c2cba..afd1c8032 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include From cb084b3c16bc532f82bff256038ddd771697d12b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 18 Jul 2023 10:21:56 -0400 Subject: [PATCH 17/17] Fix spacing in DecisionTree::print --- gtsam/discrete/DecisionTree-inl.h | 2 +- .../hybrid/tests/testGaussianMixtureFactor.cpp | 4 ++-- .../tests/testHybridNonlinearFactorGraph.cpp | 18 +++++++++--------- gtsam/hybrid/tests/testMixtureFactor.cpp | 4 ++-- gtsam/linear/GaussianConditional.cpp | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 19e3e2887..f998a6065 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -93,7 +93,7 @@ namespace gtsam { /// print void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { - std::cout << s << " Leaf [" << nrAssignments() << "]" + std::cout << s << " Leaf [" << nrAssignments() << "] " << valueFormatter(constant_) << std::endl; } diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 5207e9372..549223497 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) { std::string expected = R"(Hybrid [x1 x2; 1]{ Choice(1) - 0 Leaf [1]: + 0 Leaf [1] : A[x1] = [ 0; 0 @@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) { b = [ 0 0 ] No noise model - 1 Leaf [1]: + 1 Leaf [1] : A[x1] = [ 0; 0 diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 7bcaf1762..12506b8af 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -492,7 +492,7 @@ factor 0: factor 1: Hybrid [x0 x1; m0]{ Choice(m0) - 0 Leaf [1]: + 0 Leaf [1] : A[x0] = [ -1 ] @@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{ b = [ -1 ] No noise model - 1 Leaf [1]: + 1 Leaf [1] : A[x0] = [ -1 ] @@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{ factor 2: Hybrid [x1 x2; m1]{ Choice(m1) - 0 Leaf [1]: + 0 Leaf [1] : A[x1] = [ -1 ] @@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{ b = [ -1 ] No noise model - 1 Leaf [1]: + 1 Leaf [1] : A[x1] = [ -1 ] @@ -550,16 +550,16 @@ factor 4: b = [ -10 ] No noise model factor 5: P( m0 ): - Leaf [2] 0.5 + Leaf [2] 0.5 factor 6: P( m1 | m0 ): Choice(m1) 0 Choice(m0) - 0 0 Leaf [1]0.33333333 - 0 1 Leaf [1] 0.6 + 0 0 Leaf [1] 0.33333333 + 0 1 Leaf [1] 0.6 1 Choice(m0) - 1 0 Leaf [1]0.66666667 - 1 1 Leaf [1] 0.4 + 1 0 Leaf [1] 0.66666667 + 1 1 Leaf [1] 0.4 )"; EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph)); diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 79188b909..03fdccff2 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) { R"(Hybrid [x1 x2; 1] MixtureFactor Choice(1) - 0 Leaf [1]Nonlinear factor on 2 keys - 1 Leaf [1]Nonlinear factor on 2 keys + 0 Leaf [1] Nonlinear factor on 2 keys + 1 Leaf [1] Nonlinear factor on 2 keys )"; EXPECT(assert_print_equal(expected, mixtureFactor)); } diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 188c31abe..0112835aa 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -99,7 +99,7 @@ namespace gtsam { /* ************************************************************************ */ void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const { - cout << s << " p("; + cout << (s.empty() ? "" : s + " ") << "p("; for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { cout << formatter(*it) << (nrFrontals() > 1 ? " " : ""); }