Merge pull request #1955 from borglab/hybrid-custom-discrete
						commit
						73f98d8cf3
					
				|  | @ -24,13 +24,13 @@ | ||||||
| #include <gtsam/hybrid/HybridValues.h> | #include <gtsam/hybrid/HybridValues.h> | ||||||
| 
 | 
 | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
|  | #include <cassert> | ||||||
| #include <random> | #include <random> | ||||||
| #include <set> | #include <set> | ||||||
| #include <stdexcept> | #include <stdexcept> | ||||||
| #include <string> | #include <string> | ||||||
| #include <utility> | #include <utility> | ||||||
| #include <vector> | #include <vector> | ||||||
| #include <cassert> |  | ||||||
| 
 | 
 | ||||||
| using namespace std; | using namespace std; | ||||||
| using std::pair; | using std::pair; | ||||||
|  |  | ||||||
|  | @ -120,13 +120,7 @@ namespace gtsam { | ||||||
|   static DecisionTreeFactor DiscreteProduct( |   static DecisionTreeFactor DiscreteProduct( | ||||||
|       const DiscreteFactorGraph& factors) { |       const DiscreteFactorGraph& factors) { | ||||||
|     // PRODUCT: multiply all factors
 |     // PRODUCT: multiply all factors
 | ||||||
| #if GTSAM_HYBRID_TIMING |  | ||||||
|     gttic_(DiscreteProduct); |  | ||||||
| #endif |  | ||||||
|     DecisionTreeFactor product = factors.product(); |     DecisionTreeFactor product = factors.product(); | ||||||
| #if GTSAM_HYBRID_TIMING |  | ||||||
|     gttoc_(DiscreteProduct); |  | ||||||
| #endif |  | ||||||
| 
 | 
 | ||||||
| #if GTSAM_HYBRID_TIMING | #if GTSAM_HYBRID_TIMING | ||||||
|     gttic_(DiscreteNormalize); |     gttic_(DiscreteNormalize); | ||||||
|  | @ -229,13 +223,7 @@ namespace gtsam { | ||||||
|     DecisionTreeFactor product = DiscreteProduct(factors); |     DecisionTreeFactor product = DiscreteProduct(factors); | ||||||
| 
 | 
 | ||||||
|     // sum out frontals, this is the factor on the separator
 |     // sum out frontals, this is the factor on the separator
 | ||||||
| #if GTSAM_HYBRID_TIMING |  | ||||||
|     gttic_(EliminateDiscreteSum); |  | ||||||
| #endif |  | ||||||
|     DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); |     DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); | ||||||
| #if GTSAM_HYBRID_TIMING |  | ||||||
|     gttoc_(EliminateDiscreteSum); |  | ||||||
| #endif |  | ||||||
| 
 | 
 | ||||||
|     // Ordering keys for the conditional so that frontalKeys are really in front
 |     // Ordering keys for the conditional so that frontalKeys are really in front
 | ||||||
|     Ordering orderedKeys; |     Ordering orderedKeys; | ||||||
|  | @ -245,14 +233,8 @@ namespace gtsam { | ||||||
|                        sum->keys().end()); |                        sum->keys().end()); | ||||||
| 
 | 
 | ||||||
|     // now divide product/sum to get conditional
 |     // now divide product/sum to get conditional
 | ||||||
| #if GTSAM_HYBRID_TIMING |  | ||||||
|     gttic_(EliminateDiscreteToDiscreteConditional); |  | ||||||
| #endif |  | ||||||
|     auto conditional = |     auto conditional = | ||||||
|         std::make_shared<DiscreteConditional>(product, *sum, orderedKeys); |         std::make_shared<DiscreteConditional>(product, *sum, orderedKeys); | ||||||
| #if GTSAM_HYBRID_TIMING |  | ||||||
|     gttoc_(EliminateDiscreteToDiscreteConditional); |  | ||||||
| #endif |  | ||||||
| 
 | 
 | ||||||
|     return {conditional, sum}; |     return {conditional, sum}; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( | ||||||
|   return std::make_shared<TableFactor>(discreteKeys, potentials); |   return std::make_shared<TableFactor>(discreteKeys, potentials); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /**
 | ||||||
|  |  * @brief Multiply all the `factors` using the machinery of the TableFactor. | ||||||
|  |  * | ||||||
|  |  * @param factors The factors to multiply as a DiscreteFactorGraph. | ||||||
|  |  * @return TableFactor | ||||||
|  |  */ | ||||||
|  | static TableFactor TableProduct(const DiscreteFactorGraph &factors) { | ||||||
|  |   // PRODUCT: multiply all factors
 | ||||||
|  | #if GTSAM_HYBRID_TIMING | ||||||
|  |   gttic_(DiscreteProduct); | ||||||
|  | #endif | ||||||
|  |   TableFactor product; | ||||||
|  |   for (auto &&factor : factors) { | ||||||
|  |     if (factor) { | ||||||
|  |       if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) { | ||||||
|  |         product = product * (*f); | ||||||
|  |       } else if (auto dtf = | ||||||
|  |                      std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) { | ||||||
|  |         product = product * TableFactor(*dtf); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | #if GTSAM_HYBRID_TIMING | ||||||
|  |   gttoc_(DiscreteProduct); | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  | #if GTSAM_HYBRID_TIMING | ||||||
|  |   gttic_(DiscreteNormalize); | ||||||
|  | #endif | ||||||
|  |   // Max over all the potentials by pretending all keys are frontal:
 | ||||||
|  |   auto denominator = product.max(product.size()); | ||||||
|  |   // Normalize the product factor to prevent underflow.
 | ||||||
|  |   product = product / (*denominator); | ||||||
|  | #if GTSAM_HYBRID_TIMING | ||||||
|  |   gttoc_(DiscreteNormalize); | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  |   return product; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /* ************************************************************************ */ | /* ************************************************************************ */ | ||||||
| static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> | static DiscreteFactorGraph CollectDiscreteFactors( | ||||||
| discreteElimination(const HybridGaussianFactorGraph &factors, |     const HybridGaussianFactorGraph &factors) { | ||||||
|                     const Ordering &frontalKeys) { |  | ||||||
|   DiscreteFactorGraph dfg; |   DiscreteFactorGraph dfg; | ||||||
| 
 | 
 | ||||||
|   for (auto &f : factors) { |   for (auto &f : factors) { | ||||||
|     if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) { |     if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) { | ||||||
|       dfg.push_back(df); |       dfg.push_back(df); | ||||||
|  | 
 | ||||||
|     } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { |     } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { | ||||||
|       // Case where we have a HybridGaussianFactor with no continuous keys.
 |       // Case where we have a HybridGaussianFactor with no continuous keys.
 | ||||||
|       // In this case, compute a discrete factor from the remaining error.
 |       // In this case, compute a discrete factor from the remaining error.
 | ||||||
|  | @ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors, | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   return dfg; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ************************************************************************ */ | ||||||
|  | static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> | ||||||
|  | discreteElimination(const HybridGaussianFactorGraph &factors, | ||||||
|  |                     const Ordering &frontalKeys) { | ||||||
|  |   DiscreteFactorGraph dfg = CollectDiscreteFactors(factors); | ||||||
|  | 
 | ||||||
| #if GTSAM_HYBRID_TIMING | #if GTSAM_HYBRID_TIMING | ||||||
|   gttic_(EliminateDiscrete); |   gttic_(EliminateDiscrete); | ||||||
| #endif | #endif | ||||||
|   // NOTE: This does sum-product. For max-product, use EliminateForMPE.
 |   // Check if separator is empty.
 | ||||||
|   auto result = EliminateDiscrete(dfg, frontalKeys); |   // This is the same as checking if the number of frontal variables
 | ||||||
|  |   // is the same as the number of variables in the DiscreteFactorGraph.
 | ||||||
|  |   // If the separator is empty, we have a clique of all the discrete variables
 | ||||||
|  |   // so we can use the TableFactor for efficiency.
 | ||||||
|  |   if (frontalKeys.size() == dfg.keys().size()) { | ||||||
|  |     // Get product factor
 | ||||||
|  |     TableFactor product = TableProduct(dfg); | ||||||
|  | 
 | ||||||
| #if GTSAM_HYBRID_TIMING | #if GTSAM_HYBRID_TIMING | ||||||
|   gttoc_(EliminateDiscrete); |     gttic_(EliminateDiscreteFormDiscreteConditional); | ||||||
|  | #endif | ||||||
|  |     auto conditional = std::make_shared<DiscreteConditional>( | ||||||
|  |         frontalKeys.size(), product.toDecisionTreeFactor()); | ||||||
|  | #if GTSAM_HYBRID_TIMING | ||||||
|  |     gttoc_(EliminateDiscreteFormDiscreteConditional); | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
|   return {std::make_shared<HybridConditional>(result.first), result.second}; |     TableFactor::shared_ptr sum = product.sum(frontalKeys); | ||||||
|  | #if GTSAM_HYBRID_TIMING | ||||||
|  |     gttoc_(EliminateDiscrete); | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  |     return {std::make_shared<HybridConditional>(conditional), sum}; | ||||||
|  | 
 | ||||||
|  |   } else { | ||||||
|  |     // Perform sum-product.
 | ||||||
|  |     auto result = EliminateDiscrete(dfg, frontalKeys); | ||||||
|  |     return {std::make_shared<HybridConditional>(result.first), result.second}; | ||||||
|  |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************ */ | /* ************************************************************************ */ | ||||||
|  |  | ||||||
|  | @ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) { | ||||||
|     EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); |     EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| int main() { | int main() { | ||||||
|   TestResult tr; |   TestResult tr; | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue