Merge branch 'hybrid-custom-discrete' into discrete-table-conditional

release/4.3a0
Varun Agrawal 2024-12-31 15:05:27 -05:00
commit 34eb0fce9b
3 changed files with 72 additions and 31 deletions

View File

@ -121,25 +121,13 @@ namespace gtsam {
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product();
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
return product;
}
@ -230,13 +218,7 @@ namespace gtsam {
DecisionTreeFactor product = ProductAndNormalize(factors);
// sum out frontals, this is the factor on the separator
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
@ -246,14 +228,8 @@ namespace gtsam {
sum->keys().end());
// now divide product/sum to get conditional
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif
return {conditional, sum};
}

View File

@ -20,7 +20,6 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
@ -257,6 +256,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
return std::make_shared<TableFactor>(discreteKeys, potentials);
}
/**
* @brief Multiply all the `factors` and normalize the
* product to prevent underflow.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return TableFactor
*/
static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
TableFactor product;
for (auto &&factor : factors) {
if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f);
} else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
product = TableFactor(product * (*dtf));
}
}
}
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
return product;
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
@ -306,13 +347,37 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
/**** NOTE: This does sum-product. ****/
// Get product factor
TableFactor product = ProductAndNormalize(dfg);
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
// All the discrete variables should form a single clique,
// so we can sum out on all the variables as frontals.
// This should give an empty separator.
Ordering orderedKeys(product.keys());
TableFactor::shared_ptr sum = product.sum(orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
// Finally, get the conditional
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif
return {std::make_shared<HybridConditional>(result.first), result.second};
return {std::make_shared<HybridConditional>(conditional), sum};
}
/* ************************************************************************ */

View File

@ -114,10 +114,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
// Check that factor is discrete and correct
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
auto factor = std::dynamic_pointer_cast<TableFactor>(result.second);
CHECK(factor);
// regression test
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5));
}
/* ************************************************************************* */
@ -329,7 +329,7 @@ TEST(HybridBayesNet, Switching) {
// Check the remaining factor for x1
CHECK(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<TableFactor>(factor_x1);
CHECK(phi_x1);
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
// We can't really check the error of the decision tree factor phi_x1, because