throw in optimize
							parent
							
								
									a1467c5e84
								
							
						
					
					
						commit
						3d4d750151
					
				|  | @ -26,48 +26,48 @@ namespace gtsam { | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors, | ||||
|                                      const KeySet &newFactorKeys) { | ||||
|                                      const KeySet &continuousKeys) { | ||||
|   // Get all the discrete keys from the factors
 | ||||
|   KeySet allDiscrete = factors.discreteKeySet(); | ||||
| 
 | ||||
|   // Create KeyVector with continuous keys followed by discrete keys.
 | ||||
|   KeyVector newKeysDiscreteLast; | ||||
|   KeyVector lastKeys; | ||||
| 
 | ||||
|   // Insert continuous keys first.
 | ||||
|   for (auto &k : newFactorKeys) { | ||||
|   for (auto &k : continuousKeys) { | ||||
|     if (!allDiscrete.exists(k)) { | ||||
|       newKeysDiscreteLast.push_back(k); | ||||
|       lastKeys.push_back(k); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Insert discrete keys at the end
 | ||||
|   std::copy(allDiscrete.begin(), allDiscrete.end(), | ||||
|             std::back_inserter(newKeysDiscreteLast)); | ||||
|             std::back_inserter(lastKeys)); | ||||
| 
 | ||||
|   const VariableIndex index(factors); | ||||
| 
 | ||||
|   // Get an ordering where the new keys are eliminated last
 | ||||
|   Ordering ordering = Ordering::ColamdConstrainedLast( | ||||
|       index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()), | ||||
|       true); | ||||
|       index, KeyVector(lastKeys.begin(), lastKeys.end()), true); | ||||
|   return ordering; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| void HybridSmoother::update(const HybridGaussianFactorGraph &graph, | ||||
| void HybridSmoother::update(const HybridGaussianFactorGraph &newFactors, | ||||
|                             std::optional<size_t> maxNrLeaves, | ||||
|                             const std::optional<Ordering> given_ordering) { | ||||
|   const KeySet originalNewFactorKeys = newFactors.keys(); | ||||
| #ifdef DEBUG_SMOOTHER | ||||
|   std::cout << "hybridBayesNet_ size before: " << hybridBayesNet_.size() | ||||
|             << std::endl; | ||||
|   std::cout << "newFactors size: " << graph.size() << std::endl; | ||||
|   std::cout << "newFactors size: " << newFactors.size() << std::endl; | ||||
| #endif | ||||
|   HybridGaussianFactorGraph updatedGraph; | ||||
|   // Add the necessary conditionals from the previous timestep(s).
 | ||||
|   std::tie(updatedGraph, hybridBayesNet_) = | ||||
|       addConditionals(graph, hybridBayesNet_); | ||||
|       addConditionals(newFactors, hybridBayesNet_); | ||||
| #ifdef DEBUG_SMOOTHER | ||||
|   // print size of graph, updatedGraph, hybridBayesNet_
 | ||||
|   // print size of newFactors, updatedGraph, hybridBayesNet_
 | ||||
|   std::cout << "updatedGraph size: " << updatedGraph.size() << std::endl; | ||||
|   std::cout << "hybridBayesNet_ size after: " << hybridBayesNet_.size() | ||||
|             << std::endl; | ||||
|  | @ -79,11 +79,11 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, | |||
|   // If no ordering provided, then we compute one
 | ||||
|   if (!given_ordering.has_value()) { | ||||
|     // Get the keys from the new factors
 | ||||
|     const KeySet newFactorKeys = graph.keys(); | ||||
|     const KeySet continuousKeysToInclude;// = newFactors.keys();
 | ||||
| 
 | ||||
|     // Since updatedGraph now has all the connected conditionals,
 | ||||
|     // we can get the correct ordering.
 | ||||
|     ordering = this->getOrdering(updatedGraph, newFactorKeys); | ||||
|     ordering = this->getOrdering(updatedGraph, continuousKeysToInclude); | ||||
|   } else { | ||||
|     ordering = *given_ordering; | ||||
|   } | ||||
|  | @ -140,12 +140,15 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| std::pair<HybridGaussianFactorGraph, HybridBayesNet> | ||||
| HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, | ||||
| HybridSmoother::addConditionals(const HybridGaussianFactorGraph &newFactors, | ||||
|                                 const HybridBayesNet &hybridBayesNet) const { | ||||
|   HybridGaussianFactorGraph graph(originalGraph); | ||||
|   HybridGaussianFactorGraph graph(newFactors); | ||||
|   HybridBayesNet updatedHybridBayesNet(hybridBayesNet); | ||||
| 
 | ||||
|   KeySet factorKeys = graph.keys(); | ||||
|   KeySet involvedKeys = newFactors.keys(); | ||||
|   auto involved = [&involvedKeys](const Key &key) { | ||||
|     return involvedKeys.find(key) != involvedKeys.end(); | ||||
|   }; | ||||
| 
 | ||||
|   // If hybridBayesNet is not empty,
 | ||||
|   // it means we have conditionals to add to the factor graph.
 | ||||
|  | @ -167,12 +170,11 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, | |||
|       auto conditional = hybridBayesNet.at(i); | ||||
| 
 | ||||
|       for (auto &key : conditional->frontals()) { | ||||
|         if (std::find(factorKeys.begin(), factorKeys.end(), key) != | ||||
|             factorKeys.end()) { | ||||
|           // Add the conditional parents to factorKeys
 | ||||
|         if (involved(key)) { | ||||
|           // Add the conditional parents to involvedKeys
 | ||||
|           // so we add those conditionals too.
 | ||||
|           for (auto &&parentKey : conditional->parents()) { | ||||
|             factorKeys.insert(parentKey); | ||||
|             involvedKeys.insert(parentKey); | ||||
|           } | ||||
|           // Break so we don't add parents twice.
 | ||||
|           break; | ||||
|  | @ -180,15 +182,14 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, | |||
|       } | ||||
|     } | ||||
| #ifdef DEBUG_SMOOTHER | ||||
|     PrintKeySet(factorKeys); | ||||
|     PrintKeySet(involvedKeys); | ||||
| #endif | ||||
| 
 | ||||
|     for (size_t i = 0; i < hybridBayesNet.size(); i++) { | ||||
|       auto conditional = hybridBayesNet.at(i); | ||||
| 
 | ||||
|       for (auto &key : conditional->frontals()) { | ||||
|         if (std::find(factorKeys.begin(), factorKeys.end(), key) != | ||||
|             factorKeys.end()) { | ||||
|         if (involved(key)) { | ||||
|           newConditionals.push_back(conditional); | ||||
| 
 | ||||
|           // Remove the conditional from the updated Bayes net
 | ||||
|  | @ -218,4 +219,21 @@ const HybridBayesNet &HybridSmoother::hybridBayesNet() const { | |||
|   return hybridBayesNet_; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| HybridValues HybridSmoother::optimize() const { | ||||
|   // Solve for the MPE
 | ||||
|   DiscreteValues mpe = hybridBayesNet_.mpe(); | ||||
| 
 | ||||
|   // Add fixed values to the MPE.
 | ||||
|   mpe.insert(fixedValues_); | ||||
| 
 | ||||
|   // Given the MPE, compute the optimal continuous values.
 | ||||
|   GaussianBayesNet gbn = hybridBayesNet_.choose(mpe); | ||||
|   const VectorValues continuous = gbn.optimize(); | ||||
|   if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) { | ||||
|     throw std::runtime_error("At least one nullptr factor in hybridBayesNet_"); | ||||
|   } | ||||
|   return HybridValues(continuous, mpe); | ||||
| } | ||||
| 
 | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -108,16 +108,7 @@ class GTSAM_EXPORT HybridSmoother { | |||
|   const HybridBayesNet& hybridBayesNet() const; | ||||
| 
 | ||||
|   /// Optimize the hybrid Bayes Net, taking into accound fixed values.
 | ||||
|   HybridValues optimize() const { | ||||
|     // Solve for the MPE
 | ||||
|     DiscreteValues mpe = hybridBayesNet_.mpe(); | ||||
| 
 | ||||
|     // Add fixed values to the MPE.
 | ||||
|     mpe.insert(fixedValues_); | ||||
| 
 | ||||
|     // Given the MPE, compute the optimal continuous values.
 | ||||
|     return HybridValues(hybridBayesNet_.optimize(mpe), mpe); | ||||
|   } | ||||
|   HybridValues optimize() const; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue