restrict method
							parent
							
								
									8746b15a4a
								
							
						
					
					
						commit
						4d1a8e5057
					
				|  | @ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const { | |||
|   return std::exp(logProbability(values)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| HybridConditional::shared_ptr HybridConditional::restrict( | ||||
|     const DiscreteValues &discreteValues) const { | ||||
|   if (auto gc = asGaussian()) { | ||||
|     return std::make_shared<HybridConditional>(gc); | ||||
|   } else if (auto dc = asDiscrete()) { | ||||
|     return std::make_shared<HybridConditional>(dc); | ||||
|   }; | ||||
| 
 | ||||
|   auto hgc = asHybrid(); | ||||
|   if (!hgc) | ||||
|     throw std::runtime_error( | ||||
|         "HybridConditional::restrict: conditional type not handled"); | ||||
| 
 | ||||
|   // Case 1: Fully determined, return corresponding Gaussian conditional
 | ||||
|   auto parentValues = discreteValues.filter(discreteKeys_); | ||||
|   if (parentValues.size() == discreteKeys_.size()) { | ||||
|     return std::make_shared<HybridConditional>(hgc->choose(parentValues)); | ||||
|   } | ||||
| 
 | ||||
|   // Case 2: Some live parents remain, build a new tree
 | ||||
|   auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_); | ||||
|   if (!unspecifiedParentKeys.empty()) { | ||||
|     auto newTree = hgc->factors(); | ||||
|     for (const auto &[key, value] : parentValues) { | ||||
|       newTree = newTree.choose(key, value); | ||||
|     } | ||||
|     return std::make_shared<HybridConditional>( | ||||
|         std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys, | ||||
|                                                     newTree)); | ||||
|   } | ||||
| 
 | ||||
|   // Case 3: No changes needed, return original
 | ||||
|   return std::make_shared<HybridConditional>(hgc); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -215,6 +215,14 @@ class GTSAM_EXPORT HybridConditional | |||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Return a HybridConditional by choosing branches based on the given discrete | ||||
|    * values. If all discrete parents are specified, return a HybridConditional | ||||
|    * which is just a GaussianConditional. If this conditional is *not* a hybrid | ||||
|    * conditional, just return that. | ||||
|    */ | ||||
|   shared_ptr restrict(const DiscreteValues& discreteValues) const; | ||||
| 
 | ||||
|   /// @}
 | ||||
| 
 | ||||
|  private: | ||||
|  |  | |||
|  | @ -316,57 +316,34 @@ auto choose(auto tree, const DiscreteValues &discreteValues) { | |||
|   return tree; | ||||
| } | ||||
| 
 | ||||
| /**
 | ||||
|  * Return a HybridConditional by choosing branches based on the given discrete | ||||
|  * values. If all discrete parents are specified, return a HybridConditional | ||||
|  * which is just a GaussianConditional. | ||||
| /* *************************************************************************
 | ||||
|  * This test verifies the behavior of the restrict method in different | ||||
|  * scenarios: | ||||
|  * - When no restrictions are applied. | ||||
|  * - When one parent is restricted. | ||||
|  * - When two parents are restricted. | ||||
|  * - When the restriction results in a Gaussian conditional. | ||||
|  */ | ||||
| HybridConditional::shared_ptr choose( | ||||
|     const HybridGaussianConditional::shared_ptr &self, | ||||
|     const DiscreteValues &discreteValues) { | ||||
|   auto parentValues = discreteValues.filter(self->discreteKeys()); | ||||
|   auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys()); | ||||
| TEST(HybridGaussianConditional, Restrict) { | ||||
|   // Create a HybridConditional with two discrete parents P(z0|m0,m1)
 | ||||
|   const auto hc = | ||||
|       std::make_shared<HybridConditional>(two_mode_measurement::hgc); | ||||
| 
 | ||||
|   // Case 1: Fully determined, return corresponding Gaussian conditional
 | ||||
|   if (parentValues.size() == self->discreteKeys().size()) { | ||||
|     return std::make_shared<HybridConditional>(self->choose(parentValues)); | ||||
|   } | ||||
| 
 | ||||
|   // Case 2: Some live parents remain, build a new tree
 | ||||
|   if (!unspecifiedParentKeys.empty()) { | ||||
|     auto newTree = self->factors(); | ||||
|     for (const auto &[key, value] : parentValues) { | ||||
|       newTree = newTree.choose(key, value); | ||||
|     } | ||||
|     return std::make_shared<HybridConditional>( | ||||
|         std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys, | ||||
|                                                     newTree)); | ||||
|   } | ||||
| 
 | ||||
|   // Case 3: No changes needed, return original
 | ||||
|   return std::make_shared<HybridConditional>(self); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Test the pruning and dead-mode removal.
 | ||||
| TEST(HybridGaussianConditional, PrunePlus) { | ||||
|   using two_mode_measurement::hgc;  // two discrete parents
 | ||||
| 
 | ||||
|   const HybridConditional::shared_ptr same = choose(hgc, {}); | ||||
|   const HybridConditional::shared_ptr same = hc->restrict({}); | ||||
|   EXPECT(same->isHybrid()); | ||||
|   EXPECT(same->asHybrid()->nrComponents() == 4); | ||||
| 
 | ||||
|   const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}}); | ||||
|   const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}}); | ||||
|   EXPECT(oneParent->isHybrid()); | ||||
|   EXPECT(oneParent->asHybrid()->nrComponents() == 2); | ||||
| 
 | ||||
|   const HybridConditional::shared_ptr oneParent2 = | ||||
|       choose(hgc, {{M(7), 0}, {M(1), 0}}); | ||||
|       hc->restrict({{M(7), 0}, {M(1), 0}}); | ||||
|   EXPECT(oneParent2->isHybrid()); | ||||
|   EXPECT(oneParent2->asHybrid()->nrComponents() == 2); | ||||
| 
 | ||||
|   const HybridConditional::shared_ptr gaussian = | ||||
|       choose(hgc, {{M(1), 0}, {M(2), 1}}); | ||||
|       hc->restrict({{M(1), 0}, {M(2), 1}}); | ||||
|   EXPECT(gaussian->asGaussian()); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue