Merge pull request #1955 from borglab/hybrid-custom-discrete
commit
73f98d8cf3
|
|
@ -24,13 +24,13 @@
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using std::pair;
|
using std::pair;
|
||||||
|
|
|
||||||
|
|
@ -120,13 +120,7 @@ namespace gtsam {
|
||||||
static DecisionTreeFactor DiscreteProduct(
|
static DecisionTreeFactor DiscreteProduct(
|
||||||
const DiscreteFactorGraph& factors) {
|
const DiscreteFactorGraph& factors) {
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(DiscreteProduct);
|
|
||||||
#endif
|
|
||||||
DecisionTreeFactor product = factors.product();
|
DecisionTreeFactor product = factors.product();
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(DiscreteProduct);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttic_(DiscreteNormalize);
|
gttic_(DiscreteNormalize);
|
||||||
|
|
@ -229,13 +223,7 @@ namespace gtsam {
|
||||||
DecisionTreeFactor product = DiscreteProduct(factors);
|
DecisionTreeFactor product = DiscreteProduct(factors);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(EliminateDiscreteSum);
|
|
||||||
#endif
|
|
||||||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(EliminateDiscreteSum);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
Ordering orderedKeys;
|
Ordering orderedKeys;
|
||||||
|
|
@ -245,14 +233,8 @@ namespace gtsam {
|
||||||
sum->keys().end());
|
sum->keys().end());
|
||||||
|
|
||||||
// now divide product/sum to get conditional
|
// now divide product/sum to get conditional
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(EliminateDiscreteToDiscreteConditional);
|
|
||||||
#endif
|
|
||||||
auto conditional =
|
auto conditional =
|
||||||
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(EliminateDiscreteToDiscreteConditional);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return {conditional, sum};
|
return {conditional, sum};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
||||||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Multiply all the `factors` using the machinery of the TableFactor.
|
||||||
|
*
|
||||||
|
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||||
|
* @return TableFactor
|
||||||
|
*/
|
||||||
|
static TableFactor TableProduct(const DiscreteFactorGraph &factors) {
|
||||||
|
// PRODUCT: multiply all factors
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(DiscreteProduct);
|
||||||
|
#endif
|
||||||
|
TableFactor product;
|
||||||
|
for (auto &&factor : factors) {
|
||||||
|
if (factor) {
|
||||||
|
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
||||||
|
product = product * (*f);
|
||||||
|
} else if (auto dtf =
|
||||||
|
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||||
|
product = product * TableFactor(*dtf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(DiscreteProduct);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(DiscreteNormalize);
|
||||||
|
#endif
|
||||||
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
|
auto denominator = product.max(product.size());
|
||||||
|
// Normalize the product factor to prevent underflow.
|
||||||
|
product = product / (*denominator);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(DiscreteNormalize);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return product;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
static DiscreteFactorGraph CollectDiscreteFactors(
|
||||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
const HybridGaussianFactorGraph &factors) {
|
||||||
const Ordering &frontalKeys) {
|
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg;
|
||||||
|
|
||||||
for (auto &f : factors) {
|
for (auto &f : factors) {
|
||||||
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
dfg.push_back(df);
|
dfg.push_back(df);
|
||||||
|
|
||||||
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||||
// Case where we have a HybridGaussianFactor with no continuous keys.
|
// Case where we have a HybridGaussianFactor with no continuous keys.
|
||||||
// In this case, compute a discrete factor from the remaining error.
|
// In this case, compute a discrete factor from the remaining error.
|
||||||
|
|
@ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return dfg;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||||
|
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
const Ordering &frontalKeys) {
|
||||||
|
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttic_(EliminateDiscrete);
|
gttic_(EliminateDiscrete);
|
||||||
#endif
|
#endif
|
||||||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
// Check if separator is empty.
|
||||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
// This is the same as checking if the number of frontal variables
|
||||||
|
// is the same as the number of variables in the DiscreteFactorGraph.
|
||||||
|
// If the separator is empty, we have a clique of all the discrete variables
|
||||||
|
// so we can use the TableFactor for efficiency.
|
||||||
|
if (frontalKeys.size() == dfg.keys().size()) {
|
||||||
|
// Get product factor
|
||||||
|
TableFactor product = TableProduct(dfg);
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||||
|
#endif
|
||||||
|
auto conditional = std::make_shared<DiscreteConditional>(
|
||||||
|
frontalKeys.size(), product.toDecisionTreeFactor());
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
TableFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttoc_(EliminateDiscrete);
|
gttoc_(EliminateDiscrete);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Perform sum-product.
|
||||||
|
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue