Merge pull request #1343 from borglab/hybrid/model-selection
commit f0cd78f2c9
@@ -47,19 +47,21 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
 /**
  * @brief Helper function to get the pruner functional.
  *
- * @param decisionTree The probability decision tree of only discrete keys.
- * @return std::function<GaussianConditional::shared_ptr(
- * const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
+ * @param prunedDecisionTree  The prob. decision tree of only discrete keys.
+ * @param conditional Conditional to prune. Used to get full assignment.
+ * @return std::function<double(const Assignment<Key> &, double)>
  */
 std::function<double(const Assignment<Key> &, double)> prunerFunc(
-    const DecisionTreeFactor &decisionTree,
+    const DecisionTreeFactor &prunedDecisionTree,
     const HybridConditional &conditional) {
   // Get the discrete keys as sets for the decision tree
   // and the Gaussian mixture.
-  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
-  auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
+  std::set<DiscreteKey> decisionTreeKeySet =
+      DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
+  std::set<DiscreteKey> conditionalKeySet =
+      DiscreteKeysAsSet(conditional.discreteKeys());

-  auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
+  auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
                     const Assignment<Key> &choices,
                     double probability) -> double {
     // typecast so we can use this to get probability value
@@ -67,17 +69,44 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
     // Case where the Gaussian mixture has the same
     // discrete keys as the decision tree.
     if (conditionalKeySet == decisionTreeKeySet) {
-      if (decisionTree(values) == 0) {
+      if (prunedDecisionTree(values) == 0) {
         return 0.0;
       } else {
         return probability;
       }
     } else {
+      // Due to branch merging (aka pruning) in DecisionTree, it is possible we
+      // get a `values` which doesn't have the full set of keys.
+      std::set<Key> valuesKeys;
+      for (auto kvp : values) {
+        valuesKeys.insert(kvp.first);
+      }
+      std::set<Key> conditionalKeys;
+      for (auto kvp : conditionalKeySet) {
+        conditionalKeys.insert(kvp.first);
+      }
+      // If true, then values is missing some keys
+      if (conditionalKeys != valuesKeys) {
+        // Get the keys present in conditionalKeys but not in valuesKeys
+        std::vector<Key> missing_keys;
+        std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
+                            valuesKeys.begin(), valuesKeys.end(),
+                            std::back_inserter(missing_keys));
+        // Insert missing keys with a default assignment.
+        for (auto missing_key : missing_keys) {
+          values[missing_key] = 0;
+        }
+      }
+
       // Now we generate the full assignment by enumerating
       // over all keys in the prunedDecisionTree.
       // First we find the differing keys
       std::vector<DiscreteKey> set_diff;
       std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
                           conditionalKeySet.begin(), conditionalKeySet.end(),
                           std::back_inserter(set_diff));

       // Now enumerate over all assignments of the differing keys
       const std::vector<DiscreteValues> assignments =
           DiscreteValues::CartesianProduct(set_diff);
       for (const DiscreteValues &assignment : assignments) {
@@ -86,7 +115,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(

         // If any one of the sub-branches are non-zero,
         // we need this probability.
-        if (decisionTree(augmented_values) > 0.0) {
+        if (prunedDecisionTree(augmented_values) > 0.0) {
           return probability;
         }
       }
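The subtle part of the pruner above is the missing-key handling: once a DecisionTree has merged branches, the callback can receive a partial assignment, which must be completed before the pruned tree is queried. A minimal sketch of that complete-and-enumerate pattern, using only the standard library with stand-in key types (binary keys assumed for brevity; the real code uses DiscreteValues::CartesianProduct):

#include <algorithm>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <vector>

using Key = int;
using Assignment = std::map<Key, size_t>;

int main() {
  std::set<Key> treeKeys = {0, 1, 2};  // keys the pruned tree depends on
  Assignment values = {{0, 1}};        // partial assignment from a merged branch

  // Find the keys the assignment is missing.
  std::set<Key> valueKeys;
  for (const auto& kv : values) valueKeys.insert(kv.first);
  std::vector<Key> missing;
  std::set_difference(treeKeys.begin(), treeKeys.end(), valueKeys.begin(),
                      valueKeys.end(), std::back_inserter(missing));

  // Enumerate every completion of the missing keys; the real pruner keeps the
  // branch if any completion has nonzero probability in the pruned tree.
  const size_t n = missing.size();
  for (size_t bits = 0; bits < (size_t(1) << n); bits++) {
    Assignment full = values;
    for (size_t i = 0; i < n; i++) full[missing[i]] = (bits >> i) & 1;
    for (const auto& kv : full)
      std::cout << "m" << kv.first << "=" << kv.second << " ";
    std::cout << "\n";
  }
}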
@@ -99,7 +128,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
 }

 /* ************************************************************************* */
-// TODO(dellaert): what is this non-const method used for? Abolish it?
 void HybridBayesNet::updateDiscreteConditionals(
     const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
   KeyVector prunedTreeKeys = prunedDecisionTree->keys();
@@ -109,8 +137,6 @@ void HybridBayesNet::updateDiscreteConditionals(
     HybridConditional::shared_ptr conditional = this->at(i);
     if (conditional->isDiscrete()) {
       auto discrete = conditional->asDiscrete();
-      KeyVector frontals(discrete->frontals().begin(),
-                         discrete->frontals().end());

       // Apply prunerFunc to the underlying AlgebraicDecisionTree
       auto discreteTree =
@@ -119,6 +145,8 @@ void HybridBayesNet::updateDiscreteConditionals(
           discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));

       // Create the new (hybrid) conditional
+      KeyVector frontals(discrete->frontals().begin(),
+                         discrete->frontals().end());
       auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
           frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
       conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
@@ -206,7 +234,7 @@ GaussianBayesNet HybridBayesNet::choose(

 /* ************************************************************************* */
 HybridValues HybridBayesNet::optimize() const {
-  // Solve for the MPE
+  // Collect all the discrete factors to compute MPE
   DiscreteBayesNet discrete_bn;
   for (auto &&conditional : *this) {
     if (conditional->isDiscrete()) {
@@ -214,6 +242,7 @@ HybridValues HybridBayesNet::optimize() const {
     }
   }

+  // Solve for the MPE
   DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();

   // Given the MPE, compute the optimal continuous values.
@@ -138,7 +138,8 @@ struct HybridAssignmentData {

 /* *************************************************************************
  */
-VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
+GaussianBayesTree HybridBayesTree::choose(
+    const DiscreteValues& assignment) const {
   GaussianBayesTree gbt;
   HybridAssignmentData rootData(assignment, 0, &gbt);
   {
@@ -151,6 +152,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
   }

+  if (!rootData.isValid()) {
+    return GaussianBayesTree();
+  }
+  return gbt;
+}
+
+/* *************************************************************************
+ */
+VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
+  GaussianBayesTree gbt = this->choose(assignment);
+  // An empty GaussianBayesTree means a clique was pruned, hence invalid.
   if (gbt.size() == 0) {
     return VectorValues();
   }
   VectorValues result = gbt.optimize();
@@ -24,6 +24,7 @@
 #include <gtsam/inference/BayesTree.h>
 #include <gtsam/inference/BayesTreeCliqueBase.h>
 #include <gtsam/inference/Conditional.h>
+#include <gtsam/linear/GaussianBayesTree.h>

 #include <string>
@@ -76,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
   /** Check equality */
   bool equals(const This& other, double tol = 1e-9) const;

+  /**
+   * @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete
+   * value assignment.
+   *
+   * @param assignment The discrete value assignment for the discrete keys.
+   * @return GaussianBayesTree
+   */
+  GaussianBayesTree choose(const DiscreteValues& assignment) const;
+
   /**
    * @brief Optimize the hybrid Bayes tree by computing the MPE for the current
    * set of discrete variables and using it to compute the best continuous
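A usage sketch for the new method (hypothetical setup: `hbt` is an already-built HybridBayesTree, and M is the usual symbol shorthand); it mirrors how HybridBayesTree::optimize is implemented in terms of choose() in the .cpp change above:

#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/Symbol.h>

using namespace gtsam;
using symbol_shorthand::M;

VectorValues solveForMode(const HybridBayesTree& hbt) {
  DiscreteValues assignment;
  assignment[M(0)] = 1;  // hypothetical mode choice
  // Extract the Gaussian Bayes tree for this fixed mode sequence.
  GaussianBayesTree gbt = hbt.choose(assignment);
  // An empty tree means the assignment hit a pruned clique.
  if (gbt.size() == 0) return VectorValues();
  return gbt.optimize();
}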
@@ -261,6 +261,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
       if (!factor) {
         return 0.0;  // If nullptr, return 0.0 probability
       } else {
+        // This is the probability q(μ) at the MLE point.
         double error =
             0.5 * std::abs(factor->augmentedInformation().determinant());
         return std::exp(-error);
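The added comment pins down the leaf semantics: with error = 0.5·|det(augmented information)|, the leaf stores q(μ) = exp(−error) at the maximum-likelihood point. The arithmetic in isolation, with Eigen and a made-up 2×2 matrix:

#include <Eigen/Dense>
#include <cmath>
#include <iostream>

int main() {
  Eigen::Matrix2d info;  // hypothetical augmented information matrix
  info << 4.0, 0.0,
          0.0, 1.0;
  double error = 0.5 * std::abs(info.determinant());   // 0.5 * |det| = 2
  std::cout << "q(mu) = " << std::exp(-error) << "\n"; // exp(-2) ≈ 0.135
}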
@@ -396,18 +397,16 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
   if (discrete_only) {
     // Case 1: we are only dealing with discrete
     return discreteElimination(factors, frontalKeys);
-  } else {
+  } else if (mapFromKeyToDiscreteKey.empty()) {
     // Case 2: we are only dealing with continuous
-    if (mapFromKeyToDiscreteKey.empty()) {
-      return continuousElimination(factors, frontalKeys);
-    } else {
-      // Case 3: We are now in the hybrid land!
+    return continuousElimination(factors, frontalKeys);
+  } else {
+    // Case 3: We are now in the hybrid land!
 #ifdef HYBRID_TIMING
-      tictoc_reset_();
+    tictoc_reset_();
 #endif
-      return hybridElimination(factors, frontalKeys, continuousSeparator,
-                               discreteSeparatorSet);
-    }
+    return hybridElimination(factors, frontalKeys, continuousSeparator,
+                             discreteSeparatorSet);
   }
 }

@@ -12,7 +12,7 @@
 /**
  * @file   HybridGaussianFactorGraph.h
  * @brief  Linearized Hybrid factor graph that uses type erasure
- * @author Fan Jiang
+ * @author Fan Jiang, Varun Agrawal
  * @date   Mar 11, 2022
  */

@@ -100,8 +100,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
 /* ************************************************************************* */
 GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
     size_t index) const {
-  return boost::dynamic_pointer_cast<GaussianMixture>(
-      hybridBayesNet_.at(index));
+  return hybridBayesNet_.atMixture(index);
 }

 /* ************************************************************************* */
@@ -104,7 +104,7 @@ class GTSAM_EXPORT HybridValues {
    * @param j The index with which the value will be associated. */
   void insert(Key j, const Vector& value) { continuous_.insert(j, value); }

-  // TODO(Shangjie)- update() and insert_or_assign(), similar to Values.h
+  // TODO(Shangjie)- insert_or_assign(), similar to Values.h

   /**
    * Read/write access to the discrete value with key \c j, throws
@@ -188,12 +188,14 @@ TEST(HybridBayesNet, Optimize) {

   HybridValues delta = hybridBayesNet->optimize();

+  // TODO(Varun) The expectedAssignment should be 111, not 101
   DiscreteValues expectedAssignment;
   expectedAssignment[M(0)] = 1;
   expectedAssignment[M(1)] = 0;
   expectedAssignment[M(2)] = 1;
   EXPECT(assert_equal(expectedAssignment, delta.discrete()));

+  // TODO(Varun) This should be all -Vector1::Ones()
   VectorValues expectedValues;
   expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
   expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
@@ -169,6 +169,57 @@ TEST(HybridBayesTree, Optimize) {
   EXPECT(assert_equal(expectedValues, delta.continuous()));
 }

+/* ****************************************************************************/
+// Test for choosing a GaussianBayesTree from a HybridBayesTree.
+TEST(HybridBayesTree, Choose) {
+  Switching s(4);
+
+  HybridGaussianISAM isam;
+  HybridGaussianFactorGraph graph1;
+
+  // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
+  for (size_t i = 1; i < 4; i++) {
+    graph1.push_back(s.linearizedFactorGraph.at(i));
+  }
+
+  // Add the Gaussian factors, 1 prior on X(0),
+  // 3 measurements on X(2), X(3), X(4)
+  graph1.push_back(s.linearizedFactorGraph.at(0));
+  for (size_t i = 4; i <= 6; i++) {
+    graph1.push_back(s.linearizedFactorGraph.at(i));
+  }
+
+  // Add the discrete factors
+  for (size_t i = 7; i <= 9; i++) {
+    graph1.push_back(s.linearizedFactorGraph.at(i));
+  }
+
+  isam.update(graph1);
+
+  DiscreteValues assignment;
+  assignment[M(0)] = 1;
+  assignment[M(1)] = 1;
+  assignment[M(2)] = 1;
+
+  GaussianBayesTree gbt = isam.choose(assignment);
+
+  Ordering ordering;
+  ordering += X(0);
+  ordering += X(1);
+  ordering += X(2);
+  ordering += X(3);
+  ordering += M(0);
+  ordering += M(1);
+  ordering += M(2);
+
+  // TODO(Varun) get segfault if ordering not provided
+  auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);
+
+  auto expected_gbt = bayesTree->choose(assignment);
+
+  EXPECT(assert_equal(expected_gbt, gbt));
+}
+
 /* ****************************************************************************/
 // Test HybridBayesTree serialization.
 TEST(HybridBayesTree, Serialization) {
@@ -72,25 +72,44 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors,
 }

 TEST(HybridEstimation, Full) {
-  size_t K = 3;
-  std::vector<double> measurements = {0, 1, 2};
+  size_t K = 6;
+  std::vector<double> measurements = {0, 1, 2, 2, 2, 3};
   // Ground truth discrete seq
-  std::vector<size_t> discrete_seq = {1, 1, 0};
+  std::vector<size_t> discrete_seq = {1, 1, 0, 0, 1};
   // Switching example of robot moving in 1D
   // with given measurements and equal mode priors.
   Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
   HybridGaussianFactorGraph graph = switching.linearizedFactorGraph;

   Ordering hybridOrdering;
-  hybridOrdering += X(0);
-  hybridOrdering += X(1);
-  hybridOrdering += X(2);
-  hybridOrdering += M(0);
-  hybridOrdering += M(1);
+  for (size_t k = 0; k < K; k++) {
+    hybridOrdering += X(k);
+  }
+  for (size_t k = 0; k < K - 1; k++) {
+    hybridOrdering += M(k);
+  }

   HybridBayesNet::shared_ptr bayesNet =
       graph.eliminateSequential(hybridOrdering);

-  EXPECT_LONGS_EQUAL(5, bayesNet->size());
+  EXPECT_LONGS_EQUAL(2 * K - 1, bayesNet->size());

   HybridValues delta = bayesNet->optimize();

+  Values initial = switching.linearizationPoint;
+  Values result = initial.retract(delta.continuous());
+
+  DiscreteValues expected_discrete;
+  for (size_t k = 0; k < K - 1; k++) {
+    expected_discrete[M(k)] = discrete_seq[k];
+  }
+  EXPECT(assert_equal(expected_discrete, delta.discrete()));
+
+  Values expected_continuous;
+  for (size_t k = 0; k < K; k++) {
+    expected_continuous.insert(X(k), measurements[k]);
+  }
+  EXPECT(assert_equal(expected_continuous, result));
 }

 /****************************************************************************/
@@ -102,8 +121,8 @@ TEST(HybridEstimation, Incremental) {
   // Ground truth discrete seq
   std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
                                       1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
-  // Switching example of robot moving in 1D with given measurements and equal
-  // mode priors.
+  // Switching example of robot moving in 1D
+  // with given measurements and equal mode priors.
   Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
   HybridSmoother smoother;
   HybridNonlinearFactorGraph graph;
@@ -209,13 +228,16 @@ std::vector<size_t> getDiscreteSequence(size_t x) {
 }

 /**
- * @brief Helper method to get the tree of unnormalized probabilities
- * as per the new elimination scheme.
+ * @brief Helper method to get the tree of
+ * unnormalized probabilities as per the elimination scheme.
+ *
+ * Used as a helper to compute q(\mu | M, Z) which is used by
+ * both P(X | M, Z) and P(M | Z).
  *
  * @param graph The HybridGaussianFactorGraph to eliminate.
  * @return AlgebraicDecisionTree<Key>
  */
-AlgebraicDecisionTree<Key> probPrimeTree(
+AlgebraicDecisionTree<Key> getProbPrimeTree(
     const HybridGaussianFactorGraph& graph) {
   HybridBayesNet::shared_ptr bayesNet;
   HybridGaussianFactorGraph::shared_ptr remainingGraph;
@@ -239,20 +261,19 @@ AlgebraicDecisionTree<Key> probPrimeTree(
   DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
                                                          vector_values);

   // Get the probPrime tree with the correct leaf probabilities
   std::vector<double> probPrimes;
   for (const DiscreteValues& assignment : assignments) {
-    double error = 0.0;
     VectorValues delta = *delta_tree(assignment);
-    for (auto factor : graph) {
-      if (factor->isHybrid()) {
-        auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
-        error += f->error(delta, assignment);
-      } else if (factor->isContinuous()) {
-        auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor);
-        error += f->inner()->error(delta);
-      }
+
+    // If VectorValues is empty, it means this is a pruned branch.
+    // Set the probPrime to 0.0.
+    if (delta.size() == 0) {
+      probPrimes.push_back(0.0);
+      continue;
     }
+
+    double error = graph.error(delta, assignment);
     probPrimes.push_back(exp(-error));
   }
   AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
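The rewritten loop defers to graph.error(delta, assignment) for the error and short-circuits pruned branches to probability zero. The resulting leaf computation in miniature (plain C++; a NaN flag stands in for the empty VectorValues of a pruned branch):

#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical per-assignment errors; NaN marks a pruned branch whose
  // delta was empty in the loop above.
  std::vector<double> errors = {0.5, 1.2, std::nan(""), 0.1};

  // Unnormalized probability exp(-error), with pruned branches set to 0.
  std::vector<double> probPrimes;
  for (double e : errors)
    probPrimes.push_back(std::isnan(e) ? 0.0 : std::exp(-e));

  for (double p : probPrimes) std::cout << p << "\n";
}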
@@ -274,10 +295,23 @@ TEST(HybridEstimation, Probability) {
   Switching switching(K, between_sigma, measurement_sigma, measurements,
                       "1/1 1/1");
   auto graph = switching.linearizedFactorGraph;
-  Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());

-  HybridBayesNet::shared_ptr bayesNet = graph.eliminateSequential(ordering);
-  auto discreteConditional = bayesNet->atDiscrete(bayesNet->size() - 3);
+  // Continuous elimination
+  Ordering continuous_ordering(graph.continuousKeys());
+  HybridBayesNet::shared_ptr bayesNet;
+  HybridGaussianFactorGraph::shared_ptr discreteGraph;
+  std::tie(bayesNet, discreteGraph) =
+      graph.eliminatePartialSequential(continuous_ordering);
+
+  // Discrete elimination
+  Ordering discrete_ordering(graph.discreteKeys());
+  auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering);
+
+  // Add the discrete conditionals to make it a full Bayes net.
+  for (auto discrete_conditional : *discreteBayesNet) {
+    bayesNet->add(discrete_conditional);
+  }
+  auto discreteConditional = discreteBayesNet->atDiscrete(0);

   HybridValues hybrid_values = bayesNet->optimize();

@@ -310,7 +344,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
   Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());

   // Get the tree of unnormalized probabilities for each mode sequence.
-  AlgebraicDecisionTree<Key> expected_probPrimeTree = probPrimeTree(graph);
+  AlgebraicDecisionTree<Key> expected_probPrimeTree = getProbPrimeTree(graph);

   // Eliminate continuous
   Ordering continuous_ordering(graph.continuousKeys());
@@ -326,8 +360,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
   DiscreteKeys discrete_keys = last_conditional->discreteKeys();

   Ordering discrete(graph.discreteKeys());
-  auto discreteBayesTree =
-      discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete);
+  auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete);

   EXPECT_LONGS_EQUAL(1, discreteBayesTree->size());
   // DiscreteBayesTree should have only 1 clique
@@ -345,8 +378,8 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
       discreteBayesTree->addClique(clique, discrete_clique);

     } else {
-      // Remove the clique from the children of the parents since it will get
-      // added again in addClique.
+      // Remove the clique from the children of the parents since
+      // it will get added again in addClique.
       auto clique_it = std::find(clique->parent()->children.begin(),
                                  clique->parent()->children.end(), clique);
       clique->parent()->children.erase(clique_it);
@@ -392,7 +425,7 @@ static HybridNonlinearFactorGraph createHybridNonlinearFactorGraph() {
 }

 /*********************************************************************************
-  // Create a hybrid nonlinear factor graph f(x0, x1, m0; z0, z1)
+  // Create a hybrid linear factor graph f(x0, x1, m0; z0, z1)
  ********************************************************************************/
 static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
   HybridNonlinearFactorGraph nfg = createHybridNonlinearFactorGraph();
@@ -81,14 +81,16 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
         self.assertEqual(hv.atDiscrete(C(0)), 1)

     @staticmethod
-    def tiny(num_measurements: int = 1):
-        """Create a tiny two variable hybrid model."""
+    def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet:
+        """
+        Create a tiny two variable hybrid model which represents
+        the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
+        """
         # Create hybrid Bayes net.
         bayesNet = gtsam.HybridBayesNet()

         # Create mode key: 0 is low-noise, 1 is high-noise.
-        modeKey = M(0)
-        mode = (modeKey, 2)
+        mode = (M(0), 2)

         # Create Gaussian mixture Z(0) = X(0) + noise for each measurement.
         I = np.eye(1)
@@ -141,14 +143,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
         return bayesNet.evaluate(sample) / fg.probPrime(
             continuous, sample.discrete())

-    def test_tiny2(self):
-        """Test a tiny two variable hybrid model, with 2 measurements."""
-        # Create the Bayes net and sample from it.
+    def test_ratio(self):
+        """
+        Given a tiny two variable hybrid model with 2 measurements, test the
+        ratio between the Bayes net model representing P(z, x, n) = P(z | x, n)P(x)P(n)
+        and the factor graph P(x, n | z) = P(x | n, z)P(n | z),
+        both of which represent the same posterior.
+        """
+        # Create the Bayes net representing the generative model P(z, x, n) = P(z | x, n)P(x)P(n)
         bayesNet = self.tiny(num_measurements=2)
-        sample = bayesNet.sample()
+        # Sample from the Bayes net.
+        sample: gtsam.HybridValues = bayesNet.sample()
         # print(sample)

         # Create a factor graph from the Bayes net with sampled measurements.
+        # The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)`
+        # and thus represents the same joint probability as the Bayes net.
         fg = HybridGaussianFactorGraph()
         for i in range(2):
             conditional = bayesNet.atMixture(i)
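The invariant behind test_ratio: the generative Bayes net density and the factor graph's unnormalized density may differ only by a constant in the hidden variables (the evidence-like normalizer). A 1-D Gaussian analogue of that check (plain C++, hypothetical model, not the GTSAM API):

#include <cmath>
#include <iostream>

constexpr double kPi = 3.141592653589793;

// Unit-variance Gaussian density; stands in for p(z | x) and p(x).
double gauss(double u) { return std::exp(-0.5 * u * u) / std::sqrt(2.0 * kPi); }

int main() {
  const double z = 0.7;  // fixed "measurement"
  for (double x : {-1.0, 0.0, 0.5, 2.0}) {
    double joint = gauss(z - x) * gauss(x);  // Bayes net: p(z | x) p(x)
    // Factor graph: the same exponentials without normalization constants.
    double probPrime = std::exp(-0.5 * ((z - x) * (z - x) + x * x));
    std::cout << "ratio = " << joint / probPrime << "\n";  // constant 1/(2*pi)
  }
}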