kill TableProduct in favor of DiscreteFactorGraph::scaledProduct
parent
82dba6322f
commit
9960f2d8dc
|
@ -45,9 +45,16 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteValues HybridBayesTree::discreteMaxProduct(
|
DiscreteValues HybridBayesTree::discreteMaxProduct(
|
||||||
const DiscreteFactorGraph& dfg) const {
|
const DiscreteFactorGraph& dfg) const {
|
||||||
TableFactor product = TableProduct(dfg);
|
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
|
||||||
|
|
||||||
DiscreteValues assignment = TableDistribution(product).argmax();
|
// Check type of product, and get as TableFactor for efficiency.
|
||||||
|
TableFactor p;
|
||||||
|
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
|
||||||
|
p = *tf;
|
||||||
|
} else {
|
||||||
|
p = TableFactor(product->toDecisionTreeFactor());
|
||||||
|
}
|
||||||
|
DiscreteValues assignment = TableDistribution(p).argmax();
|
||||||
return assignment;
|
return assignment;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -255,43 +255,6 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
||||||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
TableFactor TableProduct(const DiscreteFactorGraph &factors) {
|
|
||||||
// PRODUCT: multiply all factors
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(DiscreteProduct);
|
|
||||||
#endif
|
|
||||||
TableFactor product;
|
|
||||||
for (auto &&factor : factors) {
|
|
||||||
if (factor) {
|
|
||||||
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(factor)) {
|
|
||||||
product = product * dtc->table();
|
|
||||||
} else if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
|
||||||
product = product * (*f);
|
|
||||||
} else if (auto dtf =
|
|
||||||
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
|
||||||
product = product * TableFactor(*dtf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(DiscreteProduct);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(DiscreteNormalize);
|
|
||||||
#endif
|
|
||||||
// Max over all the potentials by pretending all keys are frontal:
|
|
||||||
auto denominator = product.max(product.size());
|
|
||||||
// Normalize the product factor to prevent underflow.
|
|
||||||
product = product / *std::dynamic_pointer_cast<TableFactor>(denominator);
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(DiscreteNormalize);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return product;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static DiscreteFactorGraph CollectDiscreteFactors(
|
static DiscreteFactorGraph CollectDiscreteFactors(
|
||||||
const HybridGaussianFactorGraph &factors) {
|
const HybridGaussianFactorGraph &factors) {
|
||||||
|
@ -357,17 +320,24 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// so we can use the TableFactor for efficiency.
|
// so we can use the TableFactor for efficiency.
|
||||||
if (frontalKeys.size() == dfg.keys().size()) {
|
if (frontalKeys.size() == dfg.keys().size()) {
|
||||||
// Get product factor
|
// Get product factor
|
||||||
TableFactor product = TableProduct(dfg);
|
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||||
#endif
|
#endif
|
||||||
auto conditional = std::make_shared<TableDistribution>(product);
|
// Check type of product, and get as TableFactor for efficiency.
|
||||||
|
TableFactor p;
|
||||||
|
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
|
||||||
|
p = *tf;
|
||||||
|
} else {
|
||||||
|
p = TableFactor(product->toDecisionTreeFactor());
|
||||||
|
}
|
||||||
|
auto conditional = std::make_shared<TableDistribution>(p);
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
DiscreteFactor::shared_ptr sum = product.sum(frontalKeys);
|
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
|
||||||
|
|
||||||
return {std::make_shared<HybridConditional>(conditional), sum};
|
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||||
|
|
||||||
|
|
|
@ -271,13 +271,4 @@ template <>
|
||||||
struct traits<HybridGaussianFactorGraph>
|
struct traits<HybridGaussianFactorGraph>
|
||||||
: public Testable<HybridGaussianFactorGraph> {};
|
: public Testable<HybridGaussianFactorGraph> {};
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Multiply all the `factors` and normalize the
|
|
||||||
* product to prevent underflow.
|
|
||||||
*
|
|
||||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
|
||||||
* @return TableFactor
|
|
||||||
*/
|
|
||||||
TableFactor TableProduct(const DiscreteFactorGraph& factors);
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
Loading…
Reference in New Issue