Merge branch 'hybrid-smoother' into city10000
						commit
						5cee0a2d34
					
				|  | @ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { | |||
| // TODO(Frank): This can be quite expensive *unless* the factors have already
 | ||||
| // been pruned before. Another, possibly faster approach is branch and bound
 | ||||
| // search to find the K-best leaves and then create a single pruned conditional.
 | ||||
| HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, | ||||
|                                      bool removeDeadModes) const { | ||||
| HybridBayesNet HybridBayesNet::prune( | ||||
|     size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const { | ||||
|   // Collect all the discrete conditionals. Could be small if already pruned.
 | ||||
|   const DiscreteBayesNet marginal = discreteMarginal(); | ||||
| 
 | ||||
|  | @ -58,25 +58,21 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, | |||
|     joint = joint * (*conditional); | ||||
|   } | ||||
| 
 | ||||
|   // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
 | ||||
|   joint.prune(maxNrLeaves); | ||||
| 
 | ||||
|   // Create the result starting with the pruned joint.
 | ||||
|   // Initialize the resulting HybridBayesNet.
 | ||||
|   HybridBayesNet result; | ||||
|   result.emplace_shared<DiscreteConditional>(joint); | ||||
| 
 | ||||
|   // Get pruned discrete probabilities so
 | ||||
|   // we can prune HybridGaussianConditionals.
 | ||||
|   DiscreteConditional pruned = *result.back()->asDiscrete(); | ||||
|   // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
 | ||||
|   DiscreteConditional pruned = joint; | ||||
|   pruned.prune(maxNrLeaves); | ||||
| 
 | ||||
|   DiscreteValues deadModesValues; | ||||
|   if (removeDeadModes) { | ||||
|   if (deadModeThreshold.has_value()) { | ||||
|     DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); | ||||
|     for (auto dkey : pruned.discreteKeys()) { | ||||
|       Vector probabilities = marginals.marginalProbabilities(dkey); | ||||
| 
 | ||||
|       int index = -1; | ||||
|       auto threshold = (probabilities.array() > 0.99); | ||||
|       auto threshold = (probabilities.array() > *deadModeThreshold); | ||||
|       // If atleast 1 value is non-zero, then we can find the index
 | ||||
|       // Else if all are zero, index would be set to 0 which is incorrect
 | ||||
|       if (!threshold.isZero()) { | ||||
|  | @ -89,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, | |||
|     } | ||||
| 
 | ||||
|     // Remove the modes (imperative)
 | ||||
|     result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); | ||||
|     pruned = *result.back()->asDiscrete(); | ||||
|     pruned.removeDiscreteModes(deadModesValues); | ||||
| 
 | ||||
|     /*
 | ||||
|       If the pruned discrete conditional has any keys left, | ||||
|       we add it to the HybridBayesNet. | ||||
|       If not, it means it is an orphan so we don't add this pruned joint, | ||||
|       and instead add only the marginals below. | ||||
|     */ | ||||
|     if (pruned.keys().size() > 0) { | ||||
|       result.emplace_shared<DiscreteConditional>(pruned); | ||||
|     } | ||||
| 
 | ||||
|     // Add the marginals for future factors
 | ||||
|     for (auto &&[key, _] : deadModesValues) { | ||||
|       result.push_back( | ||||
|           std::dynamic_pointer_cast<DiscreteConditional>(marginals(key))); | ||||
|     } | ||||
| 
 | ||||
|   } else { | ||||
|     result.emplace_shared<DiscreteConditional>(pruned); | ||||
|   } | ||||
| 
 | ||||
|   /* To prune, we visitWith every leaf in the HybridGaussianConditional.
 | ||||
|  | @ -101,13 +115,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, | |||
|    */ | ||||
| 
 | ||||
|   // Go through all the Gaussian conditionals in the Bayes Net and prune them as
 | ||||
|   // per pruned Discrete joint.
 | ||||
|   // per pruned discrete joint.
 | ||||
|   for (auto &&conditional : *this) { | ||||
|     if (auto hgc = conditional->asHybrid()) { | ||||
|       // Prune the hybrid Gaussian conditional!
 | ||||
|       auto prunedHybridGaussianConditional = hgc->prune(pruned); | ||||
| 
 | ||||
|       if (removeDeadModes) { | ||||
|       if (deadModeThreshold.has_value()) { | ||||
|         KeyVector deadKeys, conditionalDiscreteKeys; | ||||
|         for (const auto &kv : deadModesValues) { | ||||
|           deadKeys.push_back(kv.first); | ||||
|  |  | |||
|  | @ -217,11 +217,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | |||
|    * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. | ||||
|    * | ||||
|    * @param maxNrLeaves Continuous values at which to compute the error. | ||||
|    * @param removeDeadModes Flag to enable removal of modes which only have a | ||||
|    * single possible assignment. | ||||
|    * @param deadModeThreshold The threshold to check the mode marginals against. | ||||
|    * If greater than this threshold, the mode gets assigned that value and is | ||||
|    * considered "dead" for hybrid elimination. | ||||
|    * The mode can then be removed since it only has a single possible | ||||
|    * assignment. | ||||
|    * @return A pruned HybridBayesNet | ||||
|    */ | ||||
|   HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; | ||||
|   HybridBayesNet prune( | ||||
|       size_t maxNrLeaves, | ||||
|       const std::optional<double> &deadModeThreshold = {}) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Error method using HybridValues which returns specific error for | ||||
|  |  | |||
|  | @ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, | |||
|   if (maxNrLeaves) { | ||||
|     // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
 | ||||
|     // all the conditionals with the same keys in bayesNetFragment.
 | ||||
|     bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves); | ||||
|     bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); | ||||
|   } | ||||
| 
 | ||||
|   // Add the partial bayes net to the posterior bayes net.
 | ||||
|  | @ -116,6 +116,14 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, | |||
|             factorKeys.end()) { | ||||
|           newConditionals.push_back(conditional); | ||||
| 
 | ||||
|           // Add the conditional parents to factorKeys
 | ||||
|           // so we add those conditionals too.
 | ||||
|           // NOTE: This assumes we have a structure where
 | ||||
|           // variables depend on those in the future.
 | ||||
|           for (auto &&parentKey : conditional->parents()) { | ||||
|             factorKeys.insert(parentKey); | ||||
|           } | ||||
| 
 | ||||
|           // Remove the conditional from the updated Bayes net
 | ||||
|           auto it = find(updatedHybridBayesNet.begin(), | ||||
|                          updatedHybridBayesNet.end(), conditional); | ||||
|  |  | |||
|  | @ -29,7 +29,20 @@ class GTSAM_EXPORT HybridSmoother { | |||
|   HybridBayesNet hybridBayesNet_; | ||||
|   HybridGaussianFactorGraph remainingFactorGraph_; | ||||
| 
 | ||||
|   /// The threshold above which we make a decision about a mode.
 | ||||
|   std::optional<double> deadModeThreshold_; | ||||
| 
 | ||||
|  public: | ||||
|   /**
 | ||||
|    * @brief Constructor | ||||
|    * | ||||
|    * @param removeDeadModes Flag indicating whether to remove dead modes. | ||||
|    * @param deadModeThreshold The threshold above which a mode gets assigned a | ||||
|    * value and is considered "dead". 0.99 is a good starting value. | ||||
|    */ | ||||
|   HybridSmoother(const std::optional<double> deadModeThreshold = {}) | ||||
|       : deadModeThreshold_(deadModeThreshold) {} | ||||
| 
 | ||||
|   /**
 | ||||
|    * Given new factors, perform an incremental update. | ||||
|    * The relevant densities in the `hybridBayesNet` will be added to the input | ||||
|  | @ -53,17 +66,6 @@ class GTSAM_EXPORT HybridSmoother { | |||
|               std::optional<size_t> maxNrLeaves = {}, | ||||
|               const std::optional<Ordering> given_ordering = {}); | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Get an elimination ordering which eliminates continuous and then | ||||
|    * discrete. | ||||
|    * | ||||
|    * Expects `newFactors` to already have the necessary conditionals connected | ||||
|    * to the | ||||
|    * | ||||
|    * @param factors | ||||
|    * @return Ordering | ||||
|    */ | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Get an elimination ordering which eliminates continuous | ||||
|    * and then discrete. | ||||
|  |  | |||
|  | @ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) { | |||
|   HybridValues delta = posterior->optimize(); | ||||
| 
 | ||||
|   // Prune the Bayes net
 | ||||
|   const bool pruneDeadVariables = true; | ||||
|   const double pruneDeadVariables = 0.99; | ||||
|   auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); | ||||
| 
 | ||||
|   // Check that discrete joint only has M0 and not (M0, M1)
 | ||||
|  | @ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) { | |||
|   // Check that hybrid conditionals that only depend on M1
 | ||||
|   // are now Gaussian and not Hybrid
 | ||||
|   EXPECT(prunedBayesNet.at(0)->isDiscrete()); | ||||
|   EXPECT(prunedBayesNet.at(1)->isHybrid()); | ||||
|   EXPECT(prunedBayesNet.at(1)->isDiscrete()); | ||||
|   EXPECT(prunedBayesNet.at(2)->isHybrid()); | ||||
|   // Only P(X2 | X1, M1) depends on M1,
 | ||||
|   // so it gets convert to a Gaussian P(X2 | X1)
 | ||||
|   EXPECT(prunedBayesNet.at(2)->isContinuous()); | ||||
|   EXPECT(prunedBayesNet.at(3)->isHybrid()); | ||||
|   EXPECT(prunedBayesNet.at(3)->isContinuous()); | ||||
|   EXPECT(prunedBayesNet.at(4)->isHybrid()); | ||||
| } | ||||
| 
 | ||||
| /* ****************************************************************************/ | ||||
|  |  | |||
|  | @ -167,6 +167,59 @@ TEST(HybridSmoother, ValidPruningError) { | |||
|   EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8); | ||||
| } | ||||
| 
 | ||||
| /****************************************************************************/ | ||||
| // Test if dead mode removal works.
 | ||||
| TEST(HybridSmoother, DeadModeRemoval) { | ||||
|   using namespace estimation_fixture; | ||||
| 
 | ||||
|   size_t K = 8; | ||||
| 
 | ||||
|   // Switching example of robot moving in 1D
 | ||||
|   // with given measurements and equal mode priors.
 | ||||
|   HybridNonlinearFactorGraph graph; | ||||
|   Values initial; | ||||
|   Switching switching = InitializeEstimationProblem( | ||||
|       K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial); | ||||
| 
 | ||||
|   // Smoother with dead mode removal enabled.
 | ||||
|   HybridSmoother smoother(true); | ||||
| 
 | ||||
|   constexpr size_t maxNrLeaves = 3; | ||||
|   for (size_t k = 1; k < K; k++) { | ||||
|     if (k > 1) graph.push_back(switching.modeChain.at(k - 1));  // Mode chain
 | ||||
|     graph.push_back(switching.binaryFactors.at(k - 1));         // Motion Model
 | ||||
|     graph.push_back(switching.unaryFactors.at(k));              // Measurement
 | ||||
| 
 | ||||
|     initial.insert(X(k), switching.linearizationPoint.at<double>(X(k))); | ||||
| 
 | ||||
|     HybridGaussianFactorGraph linearized = *graph.linearize(initial); | ||||
| 
 | ||||
|     smoother.update(linearized, maxNrLeaves); | ||||
| 
 | ||||
|     // Clear all the factors from the graph
 | ||||
|     graph.resize(0); | ||||
|   } | ||||
| 
 | ||||
|   // Get the continuous delta update as well as
 | ||||
|   // the optimal discrete assignment.
 | ||||
|   HybridValues delta = smoother.hybridBayesNet().optimize(); | ||||
| 
 | ||||
|   // Check discrete assignment
 | ||||
|   DiscreteValues expected_discrete; | ||||
|   for (size_t k = 0; k < K - 1; k++) { | ||||
|     expected_discrete[M(k)] = discrete_seq[k]; | ||||
|   } | ||||
|   EXPECT(assert_equal(expected_discrete, delta.discrete())); | ||||
| 
 | ||||
|   // Update nonlinear solution and verify
 | ||||
|   Values result = initial.retract(delta.continuous()); | ||||
|   Values expected_continuous; | ||||
|   for (size_t k = 0; k < K; k++) { | ||||
|     expected_continuous.insert(X(k), measurements[k]); | ||||
|   } | ||||
|   EXPECT(assert_equal(expected_continuous, result)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue