Merge pull request #1955 from borglab/hybrid-custom-discrete

release/4.3a0
Varun Agrawal 2025-01-03 13:41:34 -05:00 committed by GitHub
commit 73f98d8cf3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 81 additions and 26 deletions

View File

@ -24,13 +24,13 @@
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <algorithm> #include <algorithm>
#include <cassert>
#include <random> #include <random>
#include <set> #include <set>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <cassert>
using namespace std; using namespace std;
using std::pair; using std::pair;

View File

@ -120,13 +120,7 @@ namespace gtsam {
static DecisionTreeFactor DiscreteProduct( static DecisionTreeFactor DiscreteProduct(
const DiscreteFactorGraph& factors) { const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product(); DecisionTreeFactor product = factors.product();
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize); gttic_(DiscreteNormalize);
@ -229,13 +223,7 @@ namespace gtsam {
DecisionTreeFactor product = DiscreteProduct(factors); DecisionTreeFactor product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys; Ordering orderedKeys;
@ -245,14 +233,8 @@ namespace gtsam {
sum->keys().end()); sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional = auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys); std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif
return {conditional, sum}; return {conditional, sum};
} }

View File

@ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
return std::make_shared<TableFactor>(discreteKeys, potentials); return std::make_shared<TableFactor>(discreteKeys, potentials);
} }
/**
* @brief Multiply all the `factors` using the machinery of the TableFactor.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return TableFactor
*/
static TableFactor TableProduct(const DiscreteFactorGraph &factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
TableFactor product;
for (auto &&factor : factors) {
if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f);
} else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
product = product * TableFactor(*dtf);
}
}
}
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
return product;
}
/* ************************************************************************ */ /* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> static DiscreteFactorGraph CollectDiscreteFactors(
discreteElimination(const HybridGaussianFactorGraph &factors, const HybridGaussianFactorGraph &factors) {
const Ordering &frontalKeys) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto &f : factors) { for (auto &f : factors) {
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) { if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(df); dfg.push_back(df);
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys. // Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute a discrete factor from the remaining error. // In this case, compute a discrete factor from the remaining error.
@ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} }
} }
return dfg;
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete); gttic_(EliminateDiscrete);
#endif #endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE. // Check if separator is empty.
auto result = EliminateDiscrete(dfg, frontalKeys); // This is the same as checking if the number of frontal variables
// is the same as the number of variables in the DiscreteFactorGraph.
// If the separator is empty, we have a clique of all the discrete variables
// so we can use the TableFactor for efficiency.
if (frontalKeys.size() == dfg.keys().size()) {
// Get product factor
TableFactor product = TableProduct(dfg);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete); gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
auto conditional = std::make_shared<DiscreteConditional>(
frontalKeys.size(), product.toDecisionTreeFactor());
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif #endif
return {std::make_shared<HybridConditional>(result.first), result.second}; TableFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif
return {std::make_shared<HybridConditional>(conditional), sum};
} else {
// Perform sum-product.
auto result = EliminateDiscrete(dfg, frontalKeys);
return {std::make_shared<HybridConditional>(result.first), result.second};
}
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
} }
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;