Merge pull request #1955 from borglab/hybrid-custom-discrete
commit
73f98d8cf3
|
|
@ -24,13 +24,13 @@
|
|||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
|
||||
using namespace std;
|
||||
using std::pair;
|
||||
|
|
|
|||
|
|
@ -120,13 +120,7 @@ namespace gtsam {
|
|||
static DecisionTreeFactor DiscreteProduct(
|
||||
const DiscreteFactorGraph& factors) {
|
||||
// PRODUCT: multiply all factors
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteProduct);
|
||||
#endif
|
||||
DecisionTreeFactor product = factors.product();
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteProduct);
|
||||
#endif
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteNormalize);
|
||||
|
|
@ -229,13 +223,7 @@ namespace gtsam {
|
|||
DecisionTreeFactor product = DiscreteProduct(factors);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteSum);
|
||||
#endif
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteSum);
|
||||
#endif
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
Ordering orderedKeys;
|
||||
|
|
@ -245,14 +233,8 @@ namespace gtsam {
|
|||
sum->keys().end());
|
||||
|
||||
// now divide product/sum to get conditional
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteToDiscreteConditional);
|
||||
#endif
|
||||
auto conditional =
|
||||
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteToDiscreteConditional);
|
||||
#endif
|
||||
|
||||
return {conditional, sum};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
|||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Multiply all the `factors` using the machinery of the TableFactor.
|
||||
*
|
||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||
* @return TableFactor
|
||||
*/
|
||||
static TableFactor TableProduct(const DiscreteFactorGraph &factors) {
|
||||
// PRODUCT: multiply all factors
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteProduct);
|
||||
#endif
|
||||
TableFactor product;
|
||||
for (auto &&factor : factors) {
|
||||
if (factor) {
|
||||
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
||||
product = product * (*f);
|
||||
} else if (auto dtf =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
product = product * TableFactor(*dtf);
|
||||
}
|
||||
}
|
||||
}
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteProduct);
|
||||
#endif
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteNormalize);
|
||||
#endif
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto denominator = product.max(product.size());
|
||||
// Normalize the product factor to prevent underflow.
|
||||
product = product / (*denominator);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteNormalize);
|
||||
#endif
|
||||
|
||||
return product;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
static DiscreteFactorGraph CollectDiscreteFactors(
|
||||
const HybridGaussianFactorGraph &factors) {
|
||||
DiscreteFactorGraph dfg;
|
||||
|
||||
for (auto &f : factors) {
|
||||
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
dfg.push_back(df);
|
||||
|
||||
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
// Case where we have a HybridGaussianFactor with no continuous keys.
|
||||
// In this case, compute a discrete factor from the remaining error.
|
||||
|
|
@ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
}
|
||||
}
|
||||
|
||||
return dfg;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscrete);
|
||||
#endif
|
||||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
// Check if separator is empty.
|
||||
// This is the same as checking if the number of frontal variables
|
||||
// is the same as the number of variables in the DiscreteFactorGraph.
|
||||
// If the separator is empty, we have a clique of all the discrete variables
|
||||
// so we can use the TableFactor for efficiency.
|
||||
if (frontalKeys.size() == dfg.keys().size()) {
|
||||
// Get product factor
|
||||
TableFactor product = TableProduct(dfg);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscrete);
|
||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
auto conditional = std::make_shared<DiscreteConditional>(
|
||||
frontalKeys.size(), product.toDecisionTreeFactor());
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
|
||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||
TableFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscrete);
|
||||
#endif
|
||||
|
||||
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||
|
||||
} else {
|
||||
// Perform sum-product.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
|||
|
|
@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
Loading…
Reference in New Issue