improve dead mode removal by checking for empty discrete joints and adding the marginals for future factors
							parent
							
								
									938ae06031
								
							
						
					
					
						commit
						7ca7e4549e
					
				|  | @ -58,15 +58,12 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, | |||
|     joint = joint * (*conditional); | ||||
|   } | ||||
| 
 | ||||
|   // Create the result starting with the pruned joint.
 | ||||
|   // Initialize the resulting HybridBayesNet.
 | ||||
|   HybridBayesNet result; | ||||
|   result.emplace_shared<DiscreteConditional>(joint); | ||||
|   // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
 | ||||
|   result.back()->asDiscrete()->prune(maxNrLeaves); | ||||
| 
 | ||||
|   // Get pruned discrete probabilities so
 | ||||
|   // we can prune HybridGaussianConditionals.
 | ||||
|   DiscreteConditional pruned = *result.back()->asDiscrete(); | ||||
|   // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
 | ||||
|   DiscreteConditional pruned = joint; | ||||
|   joint.prune(maxNrLeaves); | ||||
| 
 | ||||
|   DiscreteValues deadModesValues; | ||||
|   if (removeDeadModes) { | ||||
|  | @ -88,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, | |||
|     } | ||||
| 
 | ||||
|     // Remove the modes (imperative)
 | ||||
|     result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); | ||||
|     pruned = *result.back()->asDiscrete(); | ||||
|     pruned.removeDiscreteModes(deadModesValues); | ||||
| 
 | ||||
|     /*
 | ||||
|       If the pruned discrete conditional has any keys left, | ||||
|       we add it to the HybridBayesNet. | ||||
|       If not, it means it is an orphan so we don't add this pruned joint, | ||||
|       and instead add only the marginals below. | ||||
|     */ | ||||
|     if (pruned.keys().size() > 0) { | ||||
|       result.emplace_shared<DiscreteConditional>(pruned); | ||||
|     } | ||||
| 
 | ||||
|     // Add the marginals for future factors
 | ||||
|     for (auto &&[key, _] : deadModesValues) { | ||||
|       result.push_back( | ||||
|           std::dynamic_pointer_cast<DiscreteConditional>(marginals(key))); | ||||
|     } | ||||
| 
 | ||||
|   } else { | ||||
|     result.emplace_shared<DiscreteConditional>(pruned); | ||||
|   } | ||||
| 
 | ||||
|   /* To prune, we visitWith every leaf in the HybridGaussianConditional.
 | ||||
|  | @ -186,6 +201,7 @@ DiscreteValues HybridBayesNet::mpe() const { | |||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   return discrete_fg.optimize(); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue