Merge pull request #1955 from borglab/hybrid-custom-discrete
						commit
						73f98d8cf3
					
				|  | @ -24,13 +24,13 @@ | |||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <cassert> | ||||
| #include <random> | ||||
| #include <set> | ||||
| #include <stdexcept> | ||||
| #include <string> | ||||
| #include <utility> | ||||
| #include <vector> | ||||
| #include <cassert> | ||||
| 
 | ||||
| using namespace std; | ||||
| using std::pair; | ||||
|  |  | |||
|  | @ -120,13 +120,7 @@ namespace gtsam { | |||
|   static DecisionTreeFactor DiscreteProduct( | ||||
|       const DiscreteFactorGraph& factors) { | ||||
|     // PRODUCT: multiply all factors
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttic_(DiscreteProduct); | ||||
| #endif | ||||
|     DecisionTreeFactor product = factors.product(); | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttoc_(DiscreteProduct); | ||||
| #endif | ||||
| 
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttic_(DiscreteNormalize); | ||||
|  | @ -229,13 +223,7 @@ namespace gtsam { | |||
|     DecisionTreeFactor product = DiscreteProduct(factors); | ||||
| 
 | ||||
|     // sum out frontals, this is the factor on the separator
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttic_(EliminateDiscreteSum); | ||||
| #endif | ||||
|     DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttoc_(EliminateDiscreteSum); | ||||
| #endif | ||||
| 
 | ||||
|     // Ordering keys for the conditional so that frontalKeys are really in front
 | ||||
|     Ordering orderedKeys; | ||||
|  | @ -245,14 +233,8 @@ namespace gtsam { | |||
|                        sum->keys().end()); | ||||
| 
 | ||||
|     // now divide product/sum to get conditional
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttic_(EliminateDiscreteToDiscreteConditional); | ||||
| #endif | ||||
|     auto conditional = | ||||
|         std::make_shared<DiscreteConditional>(product, *sum, orderedKeys); | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttoc_(EliminateDiscreteToDiscreteConditional); | ||||
| #endif | ||||
| 
 | ||||
|     return {conditional, sum}; | ||||
|   } | ||||
|  |  | |||
|  | @ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( | |||
|   return std::make_shared<TableFactor>(discreteKeys, potentials); | ||||
| } | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Multiply all the `factors` using the machinery of the TableFactor. | ||||
|  * | ||||
|  * @param factors The factors to multiply as a DiscreteFactorGraph. | ||||
|  * @return TableFactor | ||||
|  */ | ||||
| static TableFactor TableProduct(const DiscreteFactorGraph &factors) { | ||||
|   // PRODUCT: multiply all factors
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttic_(DiscreteProduct); | ||||
| #endif | ||||
|   TableFactor product; | ||||
|   for (auto &&factor : factors) { | ||||
|     if (factor) { | ||||
|       if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) { | ||||
|         product = product * (*f); | ||||
|       } else if (auto dtf = | ||||
|                      std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) { | ||||
|         product = product * TableFactor(*dtf); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttoc_(DiscreteProduct); | ||||
| #endif | ||||
| 
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttic_(DiscreteNormalize); | ||||
| #endif | ||||
|   // Max over all the potentials by pretending all keys are frontal:
 | ||||
|   auto denominator = product.max(product.size()); | ||||
|   // Normalize the product factor to prevent underflow.
 | ||||
|   product = product / (*denominator); | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttoc_(DiscreteNormalize); | ||||
| #endif | ||||
| 
 | ||||
|   return product; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> | ||||
| discreteElimination(const HybridGaussianFactorGraph &factors, | ||||
|                     const Ordering &frontalKeys) { | ||||
| static DiscreteFactorGraph CollectDiscreteFactors( | ||||
|     const HybridGaussianFactorGraph &factors) { | ||||
|   DiscreteFactorGraph dfg; | ||||
| 
 | ||||
|   for (auto &f : factors) { | ||||
|     if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) { | ||||
|       dfg.push_back(df); | ||||
| 
 | ||||
|     } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { | ||||
|       // Case where we have a HybridGaussianFactor with no continuous keys.
 | ||||
|       // In this case, compute a discrete factor from the remaining error.
 | ||||
|  | @ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors, | |||
|     } | ||||
|   } | ||||
| 
 | ||||
|   return dfg; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> | ||||
| discreteElimination(const HybridGaussianFactorGraph &factors, | ||||
|                     const Ordering &frontalKeys) { | ||||
|   DiscreteFactorGraph dfg = CollectDiscreteFactors(factors); | ||||
| 
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttic_(EliminateDiscrete); | ||||
| #endif | ||||
|   // NOTE: This does sum-product. For max-product, use EliminateForMPE.
 | ||||
|   auto result = EliminateDiscrete(dfg, frontalKeys); | ||||
|   // Check if separator is empty.
 | ||||
|   // This is the same as checking if the number of frontal variables
 | ||||
|   // is the same as the number of variables in the DiscreteFactorGraph.
 | ||||
|   // If the separator is empty, we have a clique of all the discrete variables
 | ||||
|   // so we can use the TableFactor for efficiency.
 | ||||
|   if (frontalKeys.size() == dfg.keys().size()) { | ||||
|     // Get product factor
 | ||||
|     TableFactor product = TableProduct(dfg); | ||||
| 
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttoc_(EliminateDiscrete); | ||||
|     gttic_(EliminateDiscreteFormDiscreteConditional); | ||||
| #endif | ||||
|     auto conditional = std::make_shared<DiscreteConditional>( | ||||
|         frontalKeys.size(), product.toDecisionTreeFactor()); | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttoc_(EliminateDiscreteFormDiscreteConditional); | ||||
| #endif | ||||
| 
 | ||||
|   return {std::make_shared<HybridConditional>(result.first), result.second}; | ||||
|     TableFactor::shared_ptr sum = product.sum(frontalKeys); | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttoc_(EliminateDiscrete); | ||||
| #endif | ||||
| 
 | ||||
|     return {std::make_shared<HybridConditional>(conditional), sum}; | ||||
| 
 | ||||
|   } else { | ||||
|     // Perform sum-product.
 | ||||
|     auto result = EliminateDiscrete(dfg, frontalKeys); | ||||
|     return {std::make_shared<HybridConditional>(result.first), result.second}; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
|  |  | |||
|  | @ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) { | |||
|     EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue