diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 485abbc37..8e01c0c76 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -47,19 +47,21 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { /** * @brief Helper function to get the pruner functional. * - * @param decisionTree The probability decision tree of only discrete keys. - * @return std::function &, const GaussianConditional::shared_ptr &)> + * @param prunedDecisionTree The prob. decision tree of only discrete keys. + * @param conditional Conditional to prune. Used to get full assignment. + * @return std::function &, double)> */ std::function &, double)> prunerFunc( - const DecisionTreeFactor &decisionTree, + const DecisionTreeFactor &prunedDecisionTree, const HybridConditional &conditional) { // Get the discrete keys as sets for the decision tree // and the Gaussian mixture. - auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); - auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys()); + std::set decisionTreeKeySet = + DiscreteKeysAsSet(prunedDecisionTree.discreteKeys()); + std::set conditionalKeySet = + DiscreteKeysAsSet(conditional.discreteKeys()); - auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet]( + auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet]( const Assignment &choices, double probability) -> double { // typecast so we can use this to get probability value @@ -67,17 +69,44 @@ std::function &, double)> prunerFunc( // Case where the Gaussian mixture has the same // discrete keys as the decision tree. if (conditionalKeySet == decisionTreeKeySet) { - if (decisionTree(values) == 0) { + if (prunedDecisionTree(values) == 0) { return 0.0; } else { return probability; } } else { + // Due to branch merging (aka pruning) in DecisionTree, it is possible we + // get a `values` which doesn't have the full set of keys. + std::set valuesKeys; + for (auto kvp : values) { + valuesKeys.insert(kvp.first); + } + std::set conditionalKeys; + for (auto kvp : conditionalKeySet) { + conditionalKeys.insert(kvp.first); + } + // If true, then values is missing some keys + if (conditionalKeys != valuesKeys) { + // Get the keys present in conditionalKeys but not in valuesKeys + std::vector missing_keys; + std::set_difference(conditionalKeys.begin(), conditionalKeys.end(), + valuesKeys.begin(), valuesKeys.end(), + std::back_inserter(missing_keys)); + // Insert missing keys with a default assignment. + for (auto missing_key : missing_keys) { + values[missing_key] = 0; + } + } + + // Now we generate the full assignment by enumerating + // over all keys in the prunedDecisionTree. + // First we find the differing keys std::vector set_diff; std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), conditionalKeySet.begin(), conditionalKeySet.end(), std::back_inserter(set_diff)); + // Now enumerate over all assignments of the differing keys const std::vector assignments = DiscreteValues::CartesianProduct(set_diff); for (const DiscreteValues &assignment : assignments) { @@ -86,7 +115,7 @@ std::function &, double)> prunerFunc( // If any one of the sub-branches are non-zero, // we need this probability. - if (decisionTree(augmented_values) > 0.0) { + if (prunedDecisionTree(augmented_values) > 0.0) { return probability; } } @@ -99,7 +128,6 @@ std::function &, double)> prunerFunc( } /* ************************************************************************* */ -// TODO(dellaert): what is this non-const method used for? Abolish it? void HybridBayesNet::updateDiscreteConditionals( const DecisionTreeFactor::shared_ptr &prunedDecisionTree) { KeyVector prunedTreeKeys = prunedDecisionTree->keys(); @@ -109,8 +137,6 @@ void HybridBayesNet::updateDiscreteConditionals( HybridConditional::shared_ptr conditional = this->at(i); if (conditional->isDiscrete()) { auto discrete = conditional->asDiscrete(); - KeyVector frontals(discrete->frontals().begin(), - discrete->frontals().end()); // Apply prunerFunc to the underlying AlgebraicDecisionTree auto discreteTree = @@ -119,6 +145,8 @@ void HybridBayesNet::updateDiscreteConditionals( discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional)); // Create the new (hybrid) conditional + KeyVector frontals(discrete->frontals().begin(), + discrete->frontals().end()); auto prunedDiscrete = boost::make_shared( frontals.size(), conditional->discreteKeys(), prunedDiscreteTree); conditional = boost::make_shared(prunedDiscrete); @@ -206,7 +234,7 @@ GaussianBayesNet HybridBayesNet::choose( /* ************************************************************************* */ HybridValues HybridBayesNet::optimize() const { - // Solve for the MPE + // Collect all the discrete factors to compute MPE DiscreteBayesNet discrete_bn; for (auto &&conditional : *this) { if (conditional->isDiscrete()) { @@ -214,6 +242,7 @@ HybridValues HybridBayesNet::optimize() const { } } + // Solve for the MPE DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize(); // Given the MPE, compute the optimal continuous values. diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 8e41f8b94..df2367cb5 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -138,7 +138,8 @@ struct HybridAssignmentData { /* ************************************************************************* */ -VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { +GaussianBayesTree HybridBayesTree::choose( + const DiscreteValues& assignment) const { GaussianBayesTree gbt; HybridAssignmentData rootData(assignment, 0, &gbt); { @@ -151,6 +152,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { } if (!rootData.isValid()) { + return GaussianBayesTree(); + } + return gbt; +} + +/* ************************************************************************* + */ +VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { + GaussianBayesTree gbt = this->choose(assignment); + // If empty GaussianBayesTree, means a clique is pruned hence invalid + if (gbt.size() == 0) { return VectorValues(); } VectorValues result = gbt.optimize(); diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 2d01aab76..628a453a6 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -24,6 +24,7 @@ #include #include #include +#include #include @@ -76,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { /** Check equality */ bool equals(const This& other, double tol = 1e-9) const; + /** + * @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete + * value assignment. + * + * @param assignment The discrete value assignment for the discrete keys. + * @return GaussianBayesTree + */ + GaussianBayesTree choose(const DiscreteValues& assignment) const; + /** * @brief Optimize the hybrid Bayes tree by computing the MPE for the current * set of discrete variables and using it to compute the best continuous diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 1a62e7cb8..aac37bc24 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -261,6 +261,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, if (!factor) { return 0.0; // If nullptr, return 0.0 probability } else { + // This is the probability q(μ) at the MLE point. double error = 0.5 * std::abs(factor->augmentedInformation().determinant()); return std::exp(-error); @@ -396,18 +397,16 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, if (discrete_only) { // Case 1: we are only dealing with discrete return discreteElimination(factors, frontalKeys); - } else { + } else if (mapFromKeyToDiscreteKey.empty()) { // Case 2: we are only dealing with continuous - if (mapFromKeyToDiscreteKey.empty()) { - return continuousElimination(factors, frontalKeys); - } else { - // Case 3: We are now in the hybrid land! + return continuousElimination(factors, frontalKeys); + } else { + // Case 3: We are now in the hybrid land! #ifdef HYBRID_TIMING - tictoc_reset_(); + tictoc_reset_(); #endif - return hybridElimination(factors, frontalKeys, continuousSeparator, - discreteSeparatorSet); - } + return hybridElimination(factors, frontalKeys, continuousSeparator, + discreteSeparatorSet); } } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 2f4590d8a..4e22bed7c 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -12,7 +12,7 @@ /** * @file HybridGaussianFactorGraph.h * @brief Linearized Hybrid factor graph that uses type erasure - * @author Fan Jiang + * @author Fan Jiang, Varun Agrawal * @date Mar 11, 2022 */ diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 07a7a4e77..ef77a2413 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -100,8 +100,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, /* ************************************************************************* */ GaussianMixture::shared_ptr HybridSmoother::gaussianMixture( size_t index) const { - return boost::dynamic_pointer_cast( - hybridBayesNet_.at(index)); + return hybridBayesNet_.atMixture(index); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h index 90decf769..ff896041e 100644 --- a/gtsam/hybrid/HybridValues.h +++ b/gtsam/hybrid/HybridValues.h @@ -104,7 +104,7 @@ class GTSAM_EXPORT HybridValues { * @param j The index with which the value will be associated. */ void insert(Key j, const Vector& value) { continuous_.insert(j, value); } - // TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h + // TODO(Shangjie)- insert_or_assign() , similar to Values.h /** * Read/write access to the discrete value with key \c j, throws diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 0170ce423..43cee6f74 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -188,12 +188,14 @@ TEST(HybridBayesNet, Optimize) { HybridValues delta = hybridBayesNet->optimize(); + //TODO(Varun) The expectedAssignment should be 111, not 101 DiscreteValues expectedAssignment; expectedAssignment[M(0)] = 1; expectedAssignment[M(1)] = 0; expectedAssignment[M(2)] = 1; EXPECT(assert_equal(expectedAssignment, delta.discrete())); + //TODO(Varun) This should be all -Vector1::Ones() VectorValues expectedValues; expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index 3992aa023..b4d049210 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -169,6 +169,57 @@ TEST(HybridBayesTree, Optimize) { EXPECT(assert_equal(expectedValues, delta.continuous())); } +/* ****************************************************************************/ +// Test for choosing a GaussianBayesTree from a HybridBayesTree. +TEST(HybridBayesTree, Choose) { + Switching s(4); + + HybridGaussianISAM isam; + HybridGaussianFactorGraph graph1; + + // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4 + for (size_t i = 1; i < 4; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the Gaussian factors, 1 prior on X(0), + // 3 measurements on X(2), X(3), X(4) + graph1.push_back(s.linearizedFactorGraph.at(0)); + for (size_t i = 4; i <= 6; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the discrete factors + for (size_t i = 7; i <= 9; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + isam.update(graph1); + + DiscreteValues assignment; + assignment[M(0)] = 1; + assignment[M(1)] = 1; + assignment[M(2)] = 1; + + GaussianBayesTree gbt = isam.choose(assignment); + + Ordering ordering; + ordering += X(0); + ordering += X(1); + ordering += X(2); + ordering += X(3); + ordering += M(0); + ordering += M(1); + ordering += M(2); + + //TODO(Varun) get segfault if ordering not provided + auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering); + + auto expected_gbt = bayesTree->choose(assignment); + + EXPECT(assert_equal(expected_gbt, gbt)); +} + /* ****************************************************************************/ // Test HybridBayesTree serialization. TEST(HybridBayesTree, Serialization) { diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index dac1f9f54..927f5c047 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -72,25 +72,44 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors, } TEST(HybridEstimation, Full) { - size_t K = 3; - std::vector measurements = {0, 1, 2}; + size_t K = 6; + std::vector measurements = {0, 1, 2, 2, 2, 3}; // Ground truth discrete seq - std::vector discrete_seq = {1, 1, 0}; + std::vector discrete_seq = {1, 1, 0, 0, 1}; // Switching example of robot moving in 1D // with given measurements and equal mode priors. Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1"); HybridGaussianFactorGraph graph = switching.linearizedFactorGraph; Ordering hybridOrdering; - hybridOrdering += X(0); - hybridOrdering += X(1); - hybridOrdering += X(2); - hybridOrdering += M(0); - hybridOrdering += M(1); + for (size_t k = 0; k < K; k++) { + hybridOrdering += X(k); + } + for (size_t k = 0; k < K - 1; k++) { + hybridOrdering += M(k); + } + HybridBayesNet::shared_ptr bayesNet = graph.eliminateSequential(hybridOrdering); - EXPECT_LONGS_EQUAL(5, bayesNet->size()); + EXPECT_LONGS_EQUAL(2 * K - 1, bayesNet->size()); + + HybridValues delta = bayesNet->optimize(); + + Values initial = switching.linearizationPoint; + Values result = initial.retract(delta.continuous()); + + DiscreteValues expected_discrete; + for (size_t k = 0; k < K - 1; k++) { + expected_discrete[M(k)] = discrete_seq[k]; + } + EXPECT(assert_equal(expected_discrete, delta.discrete())); + + Values expected_continuous; + for (size_t k = 0; k < K; k++) { + expected_continuous.insert(X(k), measurements[k]); + } + EXPECT(assert_equal(expected_continuous, result)); } /****************************************************************************/ @@ -102,8 +121,8 @@ TEST(HybridEstimation, Incremental) { // Ground truth discrete seq std::vector discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0}; - // Switching example of robot moving in 1D with given measurements and equal - // mode priors. + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1"); HybridSmoother smoother; HybridNonlinearFactorGraph graph; @@ -209,13 +228,16 @@ std::vector getDiscreteSequence(size_t x) { } /** - * @brief Helper method to get the tree of unnormalized probabilities - * as per the new elimination scheme. + * @brief Helper method to get the tree of + * unnormalized probabilities as per the elimination scheme. + * + * Used as a helper to compute q(\mu | M, Z) which is used by + * both P(X | M, Z) and P(M | Z). * * @param graph The HybridGaussianFactorGraph to eliminate. * @return AlgebraicDecisionTree */ -AlgebraicDecisionTree probPrimeTree( +AlgebraicDecisionTree getProbPrimeTree( const HybridGaussianFactorGraph& graph) { HybridBayesNet::shared_ptr bayesNet; HybridGaussianFactorGraph::shared_ptr remainingGraph; @@ -239,20 +261,19 @@ AlgebraicDecisionTree probPrimeTree( DecisionTree delta_tree(discrete_keys, vector_values); + // Get the probPrime tree with the correct leaf probabilities std::vector probPrimes; for (const DiscreteValues& assignment : assignments) { - double error = 0.0; VectorValues delta = *delta_tree(assignment); - for (auto factor : graph) { - if (factor->isHybrid()) { - auto f = boost::static_pointer_cast(factor); - error += f->error(delta, assignment); - } else if (factor->isContinuous()) { - auto f = boost::static_pointer_cast(factor); - error += f->inner()->error(delta); - } + // If VectorValues is empty, it means this is a pruned branch. + // Set the probPrime to 0.0. + if (delta.size() == 0) { + probPrimes.push_back(0.0); + continue; } + + double error = graph.error(delta, assignment); probPrimes.push_back(exp(-error)); } AlgebraicDecisionTree probPrimeTree(discrete_keys, probPrimes); @@ -274,10 +295,23 @@ TEST(HybridEstimation, Probability) { Switching switching(K, between_sigma, measurement_sigma, measurements, "1/1 1/1"); auto graph = switching.linearizedFactorGraph; - Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph()); - HybridBayesNet::shared_ptr bayesNet = graph.eliminateSequential(ordering); - auto discreteConditional = bayesNet->atDiscrete(bayesNet->size() - 3); + // Continuous elimination + Ordering continuous_ordering(graph.continuousKeys()); + HybridBayesNet::shared_ptr bayesNet; + HybridGaussianFactorGraph::shared_ptr discreteGraph; + std::tie(bayesNet, discreteGraph) = + graph.eliminatePartialSequential(continuous_ordering); + + // Discrete elimination + Ordering discrete_ordering(graph.discreteKeys()); + auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering); + + // Add the discrete conditionals to make it a full bayes net. + for (auto discrete_conditional : *discreteBayesNet) { + bayesNet->add(discrete_conditional); + } + auto discreteConditional = discreteBayesNet->atDiscrete(0); HybridValues hybrid_values = bayesNet->optimize(); @@ -310,7 +344,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) { Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph()); // Get the tree of unnormalized probabilities for each mode sequence. - AlgebraicDecisionTree expected_probPrimeTree = probPrimeTree(graph); + AlgebraicDecisionTree expected_probPrimeTree = getProbPrimeTree(graph); // Eliminate continuous Ordering continuous_ordering(graph.continuousKeys()); @@ -326,8 +360,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) { DiscreteKeys discrete_keys = last_conditional->discreteKeys(); Ordering discrete(graph.discreteKeys()); - auto discreteBayesTree = - discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete); + auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete); EXPECT_LONGS_EQUAL(1, discreteBayesTree->size()); // DiscreteBayesTree should have only 1 clique @@ -345,8 +378,8 @@ TEST(HybridEstimation, ProbabilityMultifrontal) { discreteBayesTree->addClique(clique, discrete_clique); } else { - // Remove the clique from the children of the parents since it will get - // added again in addClique. + // Remove the clique from the children of the parents since + // it will get added again in addClique. auto clique_it = std::find(clique->parent()->children.begin(), clique->parent()->children.end(), clique); clique->parent()->children.erase(clique_it); @@ -392,7 +425,7 @@ static HybridNonlinearFactorGraph createHybridNonlinearFactorGraph() { } /********************************************************************************* - // Create a hybrid nonlinear factor graph f(x0, x1, m0; z0, z1) + // Create a hybrid linear factor graph f(x0, x1, m0; z0, z1) ********************************************************************************/ static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() { HybridNonlinearFactorGraph nfg = createHybridNonlinearFactorGraph(); diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 53ff6354e..481617db1 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -81,14 +81,16 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): self.assertEqual(hv.atDiscrete(C(0)), 1) @staticmethod - def tiny(num_measurements: int = 1): - """Create a tiny two variable hybrid model.""" + def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet: + """ + Create a tiny two variable hybrid model which represents + the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). + """ # Create hybrid Bayes net. bayesNet = gtsam.HybridBayesNet() # Create mode key: 0 is low-noise, 1 is high-noise. - modeKey = M(0) - mode = (modeKey, 2) + mode = (M(0), 2) # Create Gaussian mixture Z(0) = X(0) + noise for each measurement. I = np.eye(1) @@ -141,14 +143,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): return bayesNet.evaluate(sample) / fg.probPrime( continuous, sample.discrete()) - def test_tiny2(self): - """Test a tiny two variable hybrid model, with 2 measurements.""" - # Create the Bayes net and sample from it. + def test_ratio(self): + """ + Given a tiny two variable hybrid model, with 2 measurements, + test the ratio of the bayes net model representing P(z, x, n)=P(z|x, n)P(x)P(n) + and the factor graph P(x, n | z)=P(x | n, z)P(n|z), + both of which represent the same posterior. + """ + # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) bayesNet = self.tiny(num_measurements=2) - sample = bayesNet.sample() + # Sample from the Bayes net. + sample: gtsam.HybridValues = bayesNet.sample() # print(sample) # Create a factor graph from the Bayes net with sampled measurements. + # The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)` + # and thus represents the same joint probability as the Bayes net. fg = HybridGaussianFactorGraph() for i in range(2): conditional = bayesNet.atMixture(i)