kill TableProduct in favor of DiscreteFactorGraph::scaledProduct
parent
82dba6322f
commit
9960f2d8dc
|
@ -45,9 +45,16 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
|
|||
/* ************************************************************************* */
|
||||
DiscreteValues HybridBayesTree::discreteMaxProduct(
|
||||
const DiscreteFactorGraph& dfg) const {
|
||||
TableFactor product = TableProduct(dfg);
|
||||
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
|
||||
|
||||
DiscreteValues assignment = TableDistribution(product).argmax();
|
||||
// Check type of product, and get as TableFactor for efficiency.
|
||||
TableFactor p;
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
|
||||
p = *tf;
|
||||
} else {
|
||||
p = TableFactor(product->toDecisionTreeFactor());
|
||||
}
|
||||
DiscreteValues assignment = TableDistribution(p).argmax();
|
||||
return assignment;
|
||||
}
|
||||
|
||||
|
|
|
@ -255,43 +255,6 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
|||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor TableProduct(const DiscreteFactorGraph &factors) {
|
||||
// PRODUCT: multiply all factors
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteProduct);
|
||||
#endif
|
||||
TableFactor product;
|
||||
for (auto &&factor : factors) {
|
||||
if (factor) {
|
||||
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(factor)) {
|
||||
product = product * dtc->table();
|
||||
} else if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
||||
product = product * (*f);
|
||||
} else if (auto dtf =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
product = product * TableFactor(*dtf);
|
||||
}
|
||||
}
|
||||
}
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteProduct);
|
||||
#endif
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteNormalize);
|
||||
#endif
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto denominator = product.max(product.size());
|
||||
// Normalize the product factor to prevent underflow.
|
||||
product = product / *std::dynamic_pointer_cast<TableFactor>(denominator);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteNormalize);
|
||||
#endif
|
||||
|
||||
return product;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static DiscreteFactorGraph CollectDiscreteFactors(
|
||||
const HybridGaussianFactorGraph &factors) {
|
||||
|
@ -357,17 +320,24 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
// so we can use the TableFactor for efficiency.
|
||||
if (frontalKeys.size() == dfg.keys().size()) {
|
||||
// Get product factor
|
||||
TableFactor product = TableProduct(dfg);
|
||||
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
auto conditional = std::make_shared<TableDistribution>(product);
|
||||
// Check type of product, and get as TableFactor for efficiency.
|
||||
TableFactor p;
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
|
||||
p = *tf;
|
||||
} else {
|
||||
p = TableFactor(product->toDecisionTreeFactor());
|
||||
}
|
||||
auto conditional = std::make_shared<TableDistribution>(p);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
|
||||
DiscreteFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
|
||||
|
||||
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||
|
||||
|
|
|
@ -271,13 +271,4 @@ template <>
|
|||
struct traits<HybridGaussianFactorGraph>
|
||||
: public Testable<HybridGaussianFactorGraph> {};
|
||||
|
||||
/**
|
||||
* @brief Multiply all the `factors` and normalize the
|
||||
* product to prevent underflow.
|
||||
*
|
||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||
* @return TableFactor
|
||||
*/
|
||||
TableFactor TableProduct(const DiscreteFactorGraph& factors);
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
Loading…
Reference in New Issue