Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
commit
34eb0fce9b
|
|
@ -121,25 +121,13 @@ namespace gtsam {
|
||||||
static DecisionTreeFactor ProductAndNormalize(
|
static DecisionTreeFactor ProductAndNormalize(
|
||||||
const DiscreteFactorGraph& factors) {
|
const DiscreteFactorGraph& factors) {
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(DiscreteProduct);
|
|
||||||
#endif
|
|
||||||
DecisionTreeFactor product = factors.product();
|
DecisionTreeFactor product = factors.product();
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(DiscreteProduct);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Max over all the potentials by pretending all keys are frontal:
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
auto normalizer = product.max(product.size());
|
auto normalizer = product.max(product.size());
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(DiscreteNormalize);
|
|
||||||
#endif
|
|
||||||
// Normalize the product factor to prevent underflow.
|
// Normalize the product factor to prevent underflow.
|
||||||
product = product / (*normalizer);
|
product = product / (*normalizer);
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(DiscreteNormalize);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return product;
|
return product;
|
||||||
}
|
}
|
||||||
|
|
@ -230,13 +218,7 @@ namespace gtsam {
|
||||||
DecisionTreeFactor product = ProductAndNormalize(factors);
|
DecisionTreeFactor product = ProductAndNormalize(factors);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(EliminateDiscreteSum);
|
|
||||||
#endif
|
|
||||||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(EliminateDiscreteSum);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
Ordering orderedKeys;
|
Ordering orderedKeys;
|
||||||
|
|
@ -246,14 +228,8 @@ namespace gtsam {
|
||||||
sum->keys().end());
|
sum->keys().end());
|
||||||
|
|
||||||
// now divide product/sum to get conditional
|
// now divide product/sum to get conditional
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(EliminateDiscreteToDiscreteConditional);
|
|
||||||
#endif
|
|
||||||
auto conditional =
|
auto conditional =
|
||||||
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(EliminateDiscreteToDiscreteConditional);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return {conditional, sum};
|
return {conditional, sum};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
#include <gtsam/base/utilities.h>
|
#include <gtsam/base/utilities.h>
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
|
|
@ -257,6 +256,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
||||||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Multiply all the `factors` and normalize the
|
||||||
|
* product to prevent underflow.
|
||||||
|
*
|
||||||
|
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||||
|
* @return TableFactor
|
||||||
|
*/
|
||||||
|
static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) {
|
||||||
|
// PRODUCT: multiply all factors
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(DiscreteProduct);
|
||||||
|
#endif
|
||||||
|
TableFactor product;
|
||||||
|
for (auto &&factor : factors) {
|
||||||
|
if (factor) {
|
||||||
|
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
||||||
|
product = product * (*f);
|
||||||
|
} else if (auto dtf =
|
||||||
|
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||||
|
product = TableFactor(product * (*dtf));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(DiscreteProduct);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
|
auto normalizer = product.max(product.size());
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(DiscreteNormalize);
|
||||||
|
#endif
|
||||||
|
// Normalize the product factor to prevent underflow.
|
||||||
|
product = product / (*normalizer);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(DiscreteNormalize);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return product;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
@ -306,13 +347,37 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttic_(EliminateDiscrete);
|
gttic_(EliminateDiscrete);
|
||||||
#endif
|
#endif
|
||||||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
/**** NOTE: This does sum-product. ****/
|
||||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
// Get product factor
|
||||||
|
TableFactor product = ProductAndNormalize(dfg);
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(EliminateDiscreteSum);
|
||||||
|
#endif
|
||||||
|
// All the discrete variables should form a single clique,
|
||||||
|
// so we can sum out on all the variables as frontals.
|
||||||
|
// This should give an empty separator.
|
||||||
|
Ordering orderedKeys(product.keys());
|
||||||
|
TableFactor::shared_ptr sum = product.sum(orderedKeys);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(EliminateDiscreteSum);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||||
|
#endif
|
||||||
|
// Finally, get the conditional
|
||||||
|
auto conditional =
|
||||||
|
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||||
|
#endif
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttoc_(EliminateDiscrete);
|
gttoc_(EliminateDiscrete);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -114,10 +114,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
|
||||||
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
|
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
|
||||||
|
|
||||||
// Check that factor is discrete and correct
|
// Check that factor is discrete and correct
|
||||||
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
auto factor = std::dynamic_pointer_cast<TableFactor>(result.second);
|
||||||
CHECK(factor);
|
CHECK(factor);
|
||||||
// regression test
|
// regression test
|
||||||
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
|
EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -329,7 +329,7 @@ TEST(HybridBayesNet, Switching) {
|
||||||
|
|
||||||
// Check the remaining factor for x1
|
// Check the remaining factor for x1
|
||||||
CHECK(factor_x1);
|
CHECK(factor_x1);
|
||||||
auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
|
auto phi_x1 = std::dynamic_pointer_cast<TableFactor>(factor_x1);
|
||||||
CHECK(phi_x1);
|
CHECK(phi_x1);
|
||||||
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
|
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
|
||||||
// We can't really check the error of the decision tree factor phi_x1, because
|
// We can't really check the error of the decision tree factor phi_x1, because
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue